You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:35 UTC

[09/23] systemml git commit: Review comments, part 1

Review comments, part 1


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b94557fd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b94557fd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b94557fd

Branch: refs/heads/master
Commit: b94557fd2c90c591179cdbf05a32242fadc36448
Parents: d88f867
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 00:35:52 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:37 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  21 +-
 .../org/apache/sysml/hops/OptimizerUtils.java   |   5 -
 .../sysml/hops/rewrite/HopRewriteUtils.java     |   3 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   4 +-
 .../RewriteAlgebraicSimplificationDynamic.java  |  16 +-
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 284 -------------------
 ...RewriteElementwiseMultChainOptimization.java | 281 ++++++++++++++++++
 .../functions/misc/RewriteEMultChainTest.java   | 127 ---------
 ...ementwiseMultChainOptimizationChainTest.java | 127 +++++++++
 .../ternary/ABATernaryAggregateTest.java        |   9 +-
 .../functions/misc/ZPackageSuite.java           |   1 +
 .../functions/ternary/ZPackageSuite.java        |   3 +-
 12 files changed, 436 insertions(+), 445 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 300a20c..a207831 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -497,7 +497,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				if (binput1.getOp() == OpOp2.POW
 						&& binput1.getInput().get(1) instanceof LiteralOp) {
 					LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
-					ret = lit.getLongValue() == 3;
+					ret = HopRewriteUtils.getIntValueSafe(lit) == 3;
 				}
 				else if (binput1.getOp() == OpOp2.MULT
 						// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
@@ -640,15 +640,10 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 		boolean handled = false;
 
 		if (input1.getOp() == OpOp2.POW) {
-			switch ((int)((LiteralOp)input12).getLongValue()) {
-			case 3:
-				in1 = input11.constructLops();
-				in2 = in1;
-				in3 = in1;
-				break;
-			default:
-				throw new AssertionError("unreachable; only applies to power 3");
-			}
+			assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
+			in1 = input11.constructLops();
+			in2 = in1;
+			in3 = in1;
 			handled = true;
 		} else if (input11 instanceof BinaryOp ) {
 			BinaryOp b11 = (BinaryOp)input11;
@@ -662,8 +657,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 			case POW: // A*A*B case
 				Hop b112 = b11.getInput().get(1);
 				if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
-						&& b112 instanceof LiteralOp
-						&& ((LiteralOp)b112).getLongValue() == 2) {
+						&& HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
 					in1 = b11.getInput().get(0).constructLops();
 					in2 = in1;
 					in3 = input12.constructLops();
@@ -682,8 +676,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				break;
 			case POW: // A*B*B case
 				Hop b112 = b12.getInput().get(1);
-				if ( b112 instanceof LiteralOp
-						&& ((LiteralOp)b112).getLongValue() == 2) {
+				if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
 					in1 = b12.getInput().get(0).constructLops();
 					in2 = in1;
 					in3 = input11.constructLops();

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 2a76d07..79b7ee6 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -111,11 +111,6 @@ public class OptimizerUtils
 	public static boolean ALLOW_CONSTANT_FOLDING = true;
 	
 	public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
-	/**
-	 * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in
-	 * `A*B*A ==> (A^2)*B`.
-	 */
-	public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
 	public static boolean ALLOW_OPERATOR_FUSION = true;
 	
 	/**

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 17ac4ec..8f71359 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -251,8 +251,7 @@ public class HopRewriteUtils
 	 * @return replacement
 	 */
 	public static Hop replaceHop(final Hop old, final Hop replacement) {
-		final ArrayList<Hop> rootParents = old.getParent();
-		if (rootParents.isEmpty())
+		if (old.getParent().isEmpty())
 			return replacement; // new old!
 		HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
 		return replacement;

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index b6aab38..1053850 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,8 @@ public class ProgramRewriter
 			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
-			if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
-				_dagRuleSet.add( new RewriteEMult()                              ); //dependency: cse
+			if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()                              ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 166af2f..91c5972 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -53,6 +53,8 @@ import org.apache.sysml.parser.DataExpression;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 
+import static org.apache.sysml.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+
 /**
  * Rule: Algebraic Simplifications. Simplifies binary expressions
  * in terms of two major purposes: (1) rewrite binary operations
@@ -2051,12 +2053,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 					&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
 					&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
 					&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
-					&& !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW)      // do not rewrite (A^2)*B
-						&& hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp  // let tak+* handle it
-						&& ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2)
-					&& !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW)      // do not rewrite B*(A^2)
-						&& hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp  // let tak+* handle it
-						&& ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2)
+					&& ( !ALLOW_SUM_PRODUCT_REWRITES
+						|| !(  HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW)     // do not rewrite (A^2)*B
+							&& hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp   // let tak+* handle it
+							&& ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 ))
+					&& ( !ALLOW_SUM_PRODUCT_REWRITES
+						|| !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW)      // do not rewrite B*(A^2)
+							&& hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp   // let tak+* handle it
+							&& ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 ))
 					)
 			{
 				baLeft = hi2.getInput().get(0);

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
deleted file mode 100644
index 5cd1471..0000000
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ /dev/null
@@ -1,284 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.rewrite;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.HopsException;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.parser.Expression;
-
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-
-/**
- * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
- *
- * Rewrite a chain of element-wise multiply hops that contain identical elements.
- * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
- *
- * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
- * since foreign parents may rely on the individual results.
- */
-public class RewriteEMult extends HopRewriteRule {
-	@Override
-	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
-		if( roots == null )
-			return null;
-		for( int i=0; i<roots.size(); i++ ) {
-			Hop h = roots.get(i);
-			roots.set(i, rule_RewriteEMult(h));
-		}
-		return roots;
-	}
-
-	@Override
-	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
-		if( root == null )
-			return null;
-		return rule_RewriteEMult(root);
-	}
-
-	private static boolean isBinaryMult(final Hop hop) {
-		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
-	}
-
-	private static Hop rule_RewriteEMult(final Hop root) {
-		if (root.isVisited())
-			return root;
-		root.setVisited();
-
-		// 1. Find immediate subtree of EMults.
-		if (isBinaryMult(root)) {
-			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
-			final Set<BinaryOp> emults = new HashSet<>();
-			final Multiset<Hop> leaves = HashMultiset.create();
-			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
-
-			// 2. Ensure it is profitable to do a rewrite.
-			if (isOptimizable(emults, leaves)) {
-				// 3. Check for foreign parents.
-				// A foreign parent is a parent of some EMult that is not in the set.
-				// Foreign parents destroy correctness of this rewrite.
-				final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
-						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
-				if (okay) {
-					// 4. Construct replacement EMults for the leaves
-					final Hop replacement = constructReplacement(leaves);
-					if (LOG.isDebugEnabled())
-						LOG.debug(String.format(
-								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
-								emults.size(), root.getHopID(), replacement.getHopID()));
-
-					// 5. Replace root with replacement
-					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
-
-					// 6. Recurse at leaves (no need to repeat the interior emults)
-					for (final Hop leaf : leaves.elementSet()) {
-						recurseInputs(leaf);
-					}
-					return newRoot;
-				}
-			}
-		}
-		// This rewrite is not applicable to the current root.
-		// Try the root's children.
-		recurseInputs(root);
-		return root;
-	}
-
-	private static void recurseInputs(final Hop parent) {
-		final ArrayList<Hop> inputs = parent.getInput();
-		for (int i = 0; i < inputs.size(); i++) {
-			final Hop input = inputs.get(i);
-			final Hop newInput = rule_RewriteEMult(input);
-			inputs.set(i, newInput);
-		}
-	}
-
-	private static Hop constructReplacement(final Multiset<Hop> leaves) {
-		// Sort by data type
-		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
-		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
-			final Hop h = entry.getElement();
-			// unlink parents (the EMults, which we are throwing away)
-			h.getParent().clear();
-			sorted.put(h, entry.getCount());
-		}
-		// sorted contains all leaves, sorted by data type, stripped from their parents
-
-		// Construct right-deep EMult tree
-		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
-		Hop first = constructPower(iterator.next());
-
-		for (int i = 1; i < sorted.size(); i++) {
-			final Hop second = constructPower(iterator.next());
-			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
-			first.setVisited();
-		}
-		return first;
-	}
-
-	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
-		final Hop hop = entry.getKey();
-		final int cnt = entry.getValue();
-		assert(cnt >= 1);
-		if (cnt == 1)
-			return hop; // don't set this visited... we will visit this next
-		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
-		pow.setVisited();
-		return pow;
-	}
-
-	/**
-	 * A Comparator that orders Hops by their data type, dimention, and sparsity.
-	 * The order is as follows:
-	 * 		scalars > row vectors > col vectors >
-	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
-	 *      other data types.
-	 * Disambiguate by Hop ID.
-	 */
-	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
-		@Override
-		public final int compare(Hop o1, Hop o2) {
-			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
-			if (c != 0) return c;
-
-			// o1 and o2 have the same data type
-			switch (o1.getDataType()) {
-			case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
-			case MATRIX:
-				// two matrices; check for vectors
-				if (o1.getDim1() == 1) { // row vector
-						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
-						return compareBySparsityThenId(o1, o2); // both row vectors
-				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
-						return -1; // row vectors are the greatest matrices
-				} else if (o1.getDim2() == 1) { // col vector
-						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
-						return compareBySparsityThenId(o1, o2); // both col vectors
-				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
-						return 1; // col vectors greater than non-vectors
-				} else { // both non-vectors
-						return compareBySparsityThenId(o1, o2);
-				}
-			default:
-				return Long.compare(o1.getHopID(), o2.getHopID());
-			}
-		}
-		private int compareBySparsityThenId(Hop o1, Hop o2) {
-			// the hop with more nnz is first; unknown nnz (-1) last
-			int c = Long.compare(o1.getNnz(), o2.getNnz());
-			if (c != 0) return c;
-			return Long.compare(o1.getHopID(), o2.getHopID());
-		}
-		private final int[] orderDataType;
-		{
-			Expression.DataType[] dtValues = Expression.DataType.values();
-			orderDataType = new int[dtValues.length];
-			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
-				switch(dtValues[i]) {
-				case SCALAR:
-					orderDataType[i] = 4;
-					break;
-				case MATRIX:
-					orderDataType[i] = 3;
-					break;
-				case FRAME:
-					orderDataType[i] = 2;
-					break;
-				case OBJECT:
-					orderDataType[i] = 1;
-					break;
-				case UNKNOWN:
-					orderDataType[i] = 0;
-					break;
-				}
-			}
-		}
-	};
-
-	/**
-	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
-	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
-	 * @param child An interior emult hop in the emult chain dag.
-	 * @return Whether this interior emult or any child emult has a foreign parent.
-	 */
-	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
-		final ArrayList<Hop> parents = child.getParent();
-		if (parents.size() > 1)
-			for (final Hop parent : parents)
-				//noinspection SuspiciousMethodCalls
-				if (!emults.contains(parent))
-					return false;
-		// child does not have foreign parents
-
-		final ArrayList<Hop> inputs = child.getInput();
-		final Hop left = inputs.get(0), right = inputs.get(1);
-		return  (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
-				(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
-	}
-
-	/**
-	 * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
-	 * @param root Root of sub-dag
-	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
-	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
-	 */
-	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
-		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
-		emults.add(root);
-
-		final ArrayList<Hop> inputs = root.getInput();
-		final Hop left = inputs.get(0), right = inputs.get(1);
-
-		if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
-		else leaves.add(left);
-
-		if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
-		else leaves.add(right);
-	}
-
-	/**
-	 * Only optimize a subtree of emults if there are at least two emults.
-	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
-	 * @param leaves The multiset of multiplicands in the emult chain.
-	 * @return If the multiset is worth optimizing.
-	 */
-	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
-		// Old criterion: there should be at least one repeated leaf
-//		for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
-//			if (hopEntry.getCount() > 1)
-//				return true;
-//		}
-//		return false;
-		return emults.size() >= 2;
-	}
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
new file mode 100644
index 0000000..bd873ff
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops that contain identical elements.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results.
+ */
+public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+		if( roots == null )
+			return null;
+		for( int i=0; i<roots.size(); i++ ) {
+			Hop h = roots.get(i);
+			roots.set(i, rule_RewriteEMult(h));
+		}
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+		if( root == null )
+			return null;
+		return rule_RewriteEMult(root);
+	}
+
+	private static boolean isBinaryMult(final Hop hop) {
+		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+	}
+
+	private static Hop rule_RewriteEMult(final Hop root) {
+		if (root.isVisited())
+			return root;
+		root.setVisited();
+
+		// 1. Find immediate subtree of EMults.
+		if (isBinaryMult(root)) {
+			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+			final Set<BinaryOp> emults = new HashSet<>();
+			final Multiset<Hop> leaves = HashMultiset.create();
+			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
+
+			// 2. Ensure it is profitable to do a rewrite.
+			if (isOptimizable(emults, leaves)) {
+				// 3. Check for foreign parents.
+				// A foreign parent is a parent of some EMult that is not in the set.
+				// Foreign parents destroy correctness of this rewrite.
+				final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+				if (okay) {
+					// 4. Construct replacement EMults for the leaves
+					final Hop replacement = constructReplacement(leaves);
+					if (LOG.isDebugEnabled())
+						LOG.debug(String.format(
+								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+								emults.size(), root.getHopID(), replacement.getHopID()));
+
+					// 5. Replace root with replacement
+					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+					// 6. Recurse at leaves (no need to repeat the interior emults)
+					for (final Hop leaf : leaves.elementSet()) {
+						recurseInputs(leaf);
+					}
+					return newRoot;
+				}
+			}
+		}
+		// This rewrite is not applicable to the current root.
+		// Try the root's children.
+		recurseInputs(root);
+		return root;
+	}
+
+	private static void recurseInputs(final Hop parent) {
+		final ArrayList<Hop> inputs = parent.getInput();
+		for (int i = 0; i < inputs.size(); i++) {
+			final Hop input = inputs.get(i);
+			final Hop newInput = rule_RewriteEMult(input);
+			inputs.set(i, newInput);
+		}
+	}
+
+	private static Hop constructReplacement(final Multiset<Hop> leaves) {
+		// Sort by data type
+		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+			final Hop h = entry.getElement();
+			// unlink parents (the EMults, which we are throwing away)
+			h.getParent().clear();
+			sorted.put(h, entry.getCount());
+		}
+		// sorted contains all leaves, sorted by data type, stripped from their parents
+
+		// Construct right-deep EMult tree
+		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		Hop first = constructPower(iterator.next());
+
+		for (int i = 1; i < sorted.size(); i++) {
+			final Hop second = constructPower(iterator.next());
+			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+			first.setVisited();
+		}
+		return first;
+	}
+
+	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+		final Hop hop = entry.getKey();
+		final int cnt = entry.getValue();
+		assert(cnt >= 1);
+		if (cnt == 1)
+			return hop; // don't set this visited... we will visit this next
+		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		pow.setVisited();
+		return pow;
+	}
+
+	/**
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
+	 * The order is as follows:
+	 * 		scalars > row vectors > col vectors >
+	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	 *      other data types.
+	 * Disambiguate by Hop ID.
+	 */
+	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		@Override
+		public final int compare(Hop o1, Hop o2) {
+			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+			if (c != 0) return c;
+
+			// o1 and o2 have the same data type
+			switch (o1.getDataType()) {
+			case MATRIX:
+				// two matrices; check for vectors
+				if (o1.getDim1() == 1) { // row vector
+						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+						return compareBySparsityThenId(o1, o2); // both row vectors
+				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+						return -1; // row vectors are the greatest matrices
+				} else if (o1.getDim2() == 1) { // col vector
+						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+						return compareBySparsityThenId(o1, o2); // both col vectors
+				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is a non-vector
+						return -1; // non-vectors are less than col vectors (mirror of the case above; keeps compare antisymmetric)
+				} else { // both non-vectors
+						return compareBySparsityThenId(o1, o2);
+				}
+			default:
+				return Long.compare(o1.getHopID(), o2.getHopID());
+			}
+		}
+		private int compareBySparsityThenId(Hop o1, Hop o2) {
+			// the hop with more nnz compares greater; unknown nnz (-1) compares least; ties broken by Hop ID
+			int c = Long.compare(o1.getNnz(), o2.getNnz());
+			if (c != 0) return c;
+			return Long.compare(o1.getHopID(), o2.getHopID());
+		}
+		private final int[] orderDataType;
+		{
+			Expression.DataType[] dtValues = Expression.DataType.values();
+			orderDataType = new int[dtValues.length];
+			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+				switch(dtValues[i]) {
+				case SCALAR:
+					orderDataType[i] = 4;
+					break;
+				case MATRIX:
+					orderDataType[i] = 3;
+					break;
+				case FRAME:
+					orderDataType[i] = 2;
+					break;
+				case OBJECT:
+					orderDataType[i] = 1;
+					break;
+				case UNKNOWN:
+					orderDataType[i] = 0;
+					break;
+				}
+			}
+		}
+	};
+
+	/**
+	 * Check that an interior emult, and recursively every emult among its inputs, has no parent outside the set of emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param child An interior emult hop in the emult chain dag.
+	 * @return False if this interior emult or any emult below it has a foreign parent; true otherwise.
+	 */
+	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+		final ArrayList<Hop> parents = child.getParent();
+		if (parents.size() > 1)
+			for (final Hop parent : parents)
+				//noinspection SuspiciousMethodCalls (for IntelliJ, which flags contains() called with a Hop on a Set<BinaryOp>)
+				if (!emults.contains(parent))
+					return false;
+		// child does not have foreign parents
+
+		final ArrayList<Hop> inputs = child.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+		return  (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+				(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+	}
+
+	/**
+	 * Collect the BinaryOp MULTs reachable from root through MULT inputs only, and the multiset of their non-MULT inputs.
+	 * @param root Root of sub-dag
+	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
+	 */
+	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
+		emults.add(root);
+
+		final ArrayList<Hop> inputs = root.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+
+		if (isBinaryMult(left))
+			findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+		else
+			leaves.add(left);
+
+		if (isBinaryMult(right))
+			findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+		else
+			leaves.add(right);
+	}
+
+	/**
+	 * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param leaves The multiset of multiplicands in the emult chain.
+	 * @return If the multiset is worth optimizing.
+	 */
+	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		return emults.size() >= 2;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
deleted file mode 100644
index 85dbea4..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
- */
-public class RewriteEMultChainTest extends AutomatedTestBase
-{
-	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
-	private static final String TEST_DIR = "functions/misc/";
-	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
-	
-	private static final int rows = 123;
-	private static final int cols = 321;
-	private static final double eps = Math.pow(10, -10);
-	
-	@Override
-	public void setUp() {
-		TestUtils.clearAssertionInformation();
-		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
-	}
-
-	@Test
-	public void testMatrixMultChainOptNoRewritesCP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptNoRewritesSP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptRewritesCP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptRewritesSP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
-	}
-
-	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
-	{	
-		RUNTIME_PLATFORM platformOld = rtplatform;
-		switch( et ){
-			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
-			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
-		}
-		
-		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-		if( rtplatform == RUNTIME_PLATFORM.SPARK )
-			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-		
-		boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
-		OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
-		
-		try
-		{
-			TestConfiguration config = getTestConfiguration(testname);
-			loadTestConfiguration(config);
-			
-			String HOME = SCRIPT_DIR + TEST_DIR;
-			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
-			fullRScriptName = HOME + testname + ".R";
-			rCmd = getRCmd(inputDir(), expectedDir());			
-
-			double Xsparsity = 0.8, Ysparsity = 0.6;
-			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
-			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
-			writeInputMatrixWithMTD("X", X, true);
-			writeInputMatrixWithMTD("Y", Y, true);
-
-			//execute tests
-			runTest(true, false, null, -1); 
-			runRScript(true); 
-			
-			//compare matrices 
-			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
-			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
-			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-			
-			//check for presence of power operator, if we did a rewrite
-			if( rewrites ) {
-				Assert.assertTrue(heavyHittersContainsSubString("^2"));
-			}
-		}
-		finally {
-			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
-			rtplatform = platformOld;
-			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
new file mode 100644
index 0000000..47b2f0e
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `A*B*A` successfully rewrites to `(A^2)*B`; each case runs with OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES forced on or off.
+ */
+public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{	
+		RUNTIME_PLATFORM platformOld = rtplatform; // saved so the finally block can restore it
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+		}
+		
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; // saved so the finally block can restore it
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES; // saved so the finally block can restore it
+		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+		
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());			
+
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+
+			//execute the DML script and the R script
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare the DML result "R" against the R script's result "R", within eps
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//check for presence of power operator ("^2" among the heavy hitters), if we did a rewrite
+			if( rewrites ) {
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 1829bf0..460829d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.rewrite.RewriteElementwiseMultChainOptimization;
 import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
@@ -36,7 +37,7 @@ import org.junit.Test;
 
 /**
  * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`.
- * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ * Checks compatibility with {@link RewriteElementwiseMultChainOptimization}.
  */
 public class ABATernaryAggregateTest extends AutomatedTestBase
 {
@@ -368,14 +369,14 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 	
 		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
-				rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+				rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
 		
 		try {
 			TestConfiguration config = getTestConfiguration(testname);
 			loadTestConfiguration(config);
 			
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
 
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
@@ -411,7 +412,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult;
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e352e6d..deea784 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,6 +50,7 @@ import org.junit.runners.Suite;
 	ReadAfterWriteTest.class,
 	RewriteCSETransposeScalarTest.class,
 	RewriteCTableToRExpandTest.class,
+	RewriteElementwiseMultChainOptimizationChainTest.class,
 	RewriteEliminateAggregatesTest.class,
 	RewriteFuseBinaryOpChainTest.class,
 	RewriteFusedRandTest.class,

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
index 784177d..ee14359 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
@@ -22,10 +22,11 @@ package org.apache.sysml.test.integration.functions.ternary;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
-/** Group together the tests in this package into a single suite so that the Maven build
+/* Group together the tests in this package into a single suite so that the Maven build
  *  won't run two of them at once. */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
+	ABATernaryAggregateTest.class,
 	CentralMomentWeightsTest.class,
 	CovarianceWeightsTest.class,
 	CTableMatrixIgnoreZerosTest.class,