You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:27 UTC

[01/23] systemml git commit: New rewrite rule for chains of element-wise multiply.

Repository: systemml
Updated Branches:
  refs/heads/master 1b3dff06b -> 85e3a9631


New rewrite rule for chains of element-wise multiply.

Placed rewrite rule after Common Subexpression Elimination.
Included helper method in HopRewriteUtils.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7d578838
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7d578838
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7d578838

Branch: refs/heads/master
Commit: 7d578838cc291a1adb6229bae01f7c9428b6f858
Parents: c434208
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jun 8 18:17:36 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:13 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  17 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   1 +
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 186 +++++++++++++++++++
 .../org/apache/sysml/parser/Expression.java     |   1 +
 4 files changed, 204 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index cf6081b..4d23cb9 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -241,7 +241,22 @@ public class HopRewriteUtils
 		parent.getInput().add( pos, child );
 		child.getParent().add( parent );
 	}
-	
+
+	/**
+	 * Replace an old Hop with a replacement Hop.
+	 * If the old Hop has no parents, then return the replacement.
+	 * Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
+	 * @return replacement
+	 */
+	public static Hop replaceHop(final Hop old, final Hop replacement) {
+		final ArrayList<Hop> rootParents = old.getParent();
+		if (rootParents.isEmpty())
+			return replacement; // new old!
+		HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
+		return replacement;
+	}
+
+
 	public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) {
 		ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent());
 		for( Hop lparent : parents )

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 0e65f3f..8573dd7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,6 +96,7 @@ public class ProgramRewriter
 			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
+			_dagRuleSet.add( new RewriteEMult()                                  ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
new file mode 100644
index 0000000..47c32a9
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops that contain identical elements.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results.
+ */
+public class RewriteEMult extends HopRewriteRule {
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+		if( roots == null )
+			return null;
+
+		for( int i=0; i<roots.size(); i++ ) {
+			Hop h = roots.get(i);
+			roots.set(i, rule_RewriteEMult(h));
+		}
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+		if( root == null )
+			return null;
+		return rule_RewriteEMult(root);
+	}
+
+	private static boolean isBinaryMult(final Hop hop) {
+		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+	}
+
+	private static Hop rule_RewriteEMult(final Hop root) {
+		if (root.isVisited())
+			return root;
+		root.setVisited();
+
+		final ArrayList<Hop> rootInputs = root.getInput();
+
+		// 1. Find immediate subtree of EMults.
+		if (isBinaryMult(root)) {
+			final Hop left = rootInputs.get(0), right = rootInputs.get(1);
+			final BinaryOp r = (BinaryOp)root;
+			final Set<BinaryOp> emults = new HashSet<>();
+			final Multiset<Hop> leaves = HashMultiset.create();
+			findEMultsAndLeaves(r, emults, leaves);
+			// 2. Ensure it is profitable to do a rewrite.
+			if (isOptimizable(leaves)) {
+				// 3. Check for foreign parents.
+				// A foreign parent is a parent of some EMult that is not in the set.
+				// Foreign parents destroy correctness of this rewrite.
+				final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+				if (okay) {
+					// 4. Construct replacement EMults for the leaves
+					final Hop replacement = constructReplacement(leaves);
+
+					// 5. Replace root with replacement
+					return HopRewriteUtils.replaceHop(root, replacement);
+				}
+			}
+		}
+
+		// This rewrite is not applicable to the current root.
+		// Try the root's children.
+		for (int i = 0; i < rootInputs.size(); i++) {
+			final Hop input = rootInputs.get(i);
+			final Hop newInput = rule_RewriteEMult(input);
+			rootInputs.set(i, newInput);
+		}
+		return root;
+	}
+
+	private static Hop constructReplacement(final Multiset<Hop> leaves) {
+		// Sort by data type
+		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+			final Hop h = entry.getElement();
+			// unlink parents (the EMults, which we are throwing away)
+			h.getParent().clear();
+			sorted.put(h, entry.getCount());
+		}
+		// sorted contains all leaves, sorted by data type, stripped from their parents
+
+		// Construct left-deep EMult tree
+		Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		Hop first = constructPower(iterator.next());
+
+		for (int i = 1; i < sorted.size(); i++) {
+			final Hop second = constructPower(iterator.next());
+			first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+		}
+		return first;
+	}
+
+	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+		final Hop hop = entry.getKey();
+		final int cnt = entry.getValue();
+		assert(cnt >= 1);
+		if (cnt == 1)
+			return hop;
+		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+	}
+
+	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType);
+
+	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+		final ArrayList<Hop> parents = child.getParent();
+		if (parents.size() > 1)
+			for (final Hop parent : parents)
+				//noinspection SuspiciousMethodCalls
+				if (!emults.contains(parent))
+					return false;
+		// child does not have foreign parents
+
+		final ArrayList<Hop> inputs = child.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+		return  (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+				(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+	}
+
+	/**
+	 * Collect every BinaryOp MULT in the immediate subtree starting at root into emults, and count each non-MULT input in the multiset leaves.
+	 */
+	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
+		emults.add(root);
+
+		final ArrayList<Hop> inputs = root.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+
+		if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+		else leaves.add(left);
+
+		if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+		else leaves.add(right);
+	}
+
+	/** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
+	private static boolean isOptimizable(final Multiset<Hop> set) {
+		for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
+			if (hopEntry.getCount() > 1)
+				return true;
+		}
+		return false;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/7d578838/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 9ee3fba..b944e29 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,6 +162,7 @@ public abstract class Expression
 	 * Data types (matrix, scalar, frame, object, unknown).
 	 */
 	public enum DataType {
+		// Careful: the order of these enums is significant! See RewriteEMult.compareByDataType
 		MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
 		
 		public boolean isMatrix() {


[12/23] systemml git commit: Fix visit status bug

Posted by mb...@apache.org.
Fix visit status bug


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0a8936cd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0a8936cd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0a8936cd

Branch: refs/heads/master
Commit: 0a8936cd849d74baced732f45f1c53812abce537
Parents: d6d3795
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 03:55:25 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:48 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/rewrite/HopDagValidator.java |  5 ++++-
 .../RewriteElementwiseMultChainOptimization.java   | 17 +++++++++--------
 2 files changed, 13 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/0a8936cd/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
index 8cb5e1e..9ac21fc 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
@@ -35,6 +35,8 @@ import org.apache.sysml.parser.Expression;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.utils.Explain;
 
+import com.google.common.collect.Lists;
+
 import static org.apache.sysml.hops.HopsException.check;
 
 /**
@@ -89,7 +91,8 @@ public class HopDagValidator {
 		//check visit status
 		final boolean seen = !state.seen.add(id);
 		check(seen == hop.isVisited(), hop,
-				"seen previously is %b but does not match hop visit status", seen);
+				"(parents: %s) seen previously is %b but does not match hop visit status",
+				Lists.transform(hop.getParent(), Hop::getHopID), seen);
 		if (seen) return; // we saw the Hop previously, no need to re-validate
 		
 		//check parent linking

http://git-wip-us.apache.org/repos/asf/systemml/blob/0a8936cd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 91b7306..9ca0932 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -91,7 +91,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
 				if (okay) {
 					// 4. Construct replacement EMults for the leaves
-					final Hop replacement = constructReplacement(leaves);
+					final Hop replacement = constructReplacement(emults, leaves);
 					if (LOG.isDebugEnabled())
 						LOG.debug(String.format(
 								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
@@ -123,13 +123,14 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		}
 	}
 
-	private static Hop constructReplacement(final Multiset<Hop> leaves) {
+	private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
 		// Sort by data type
 		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
 		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
 			final Hop h = entry.getElement();
-			// unlink parents (the EMults, which we are throwing away)
-			h.getParent().clear();
+			// unlink parents that are in the emult set (we are throwing them away)
+			// keep other parents
+			h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
 			sorted.put(h, entry.getCount());
 		}
 		// sorted contains all leaves, sorted by data type, stripped from their parents
@@ -146,12 +147,13 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		return first;
 	}
 
-	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+	private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
 		final Hop hop = entry.getKey();
 		final int cnt = entry.getValue();
 		assert(cnt >= 1);
+		hop.setVisited(); // we will visit the leaves' children next
 		if (cnt == 1)
-			return hop; // don't set this visited... we will visit this next
+			return hop;
 		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 		pow.setVisited();
 		return pow;
@@ -222,8 +224,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		final ArrayList<Hop> parents = child.getParent();
 		if (parents.size() > 1)
 			for (final Hop parent : parents)
-				//noinspection SuspiciousMethodCalls (for Intellij, which checks when
-				if (!emults.contains(parent))
+				if (parent instanceof BinaryOp && !emults.contains(parent))
 					return false;
 		// child does not have foreign parents
 


[13/23] systemml git commit: Relax tolerance for ElementwiseAdditionMultiplicationTest

Posted by mb...@apache.org.
Relax tolerance for ElementwiseAdditionMultiplicationTest

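The new RewriteElementwiseMultChainOptimization reorders
`(A*B)*C` to `A*(B*C)`. Floating-point multiplication is not associative,
so the result is no longer bit-for-bit identical to the reference.
Compare the results with an epsilon of 1e-10 instead.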


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/de469d23
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/de469d23
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/de469d23

Branch: refs/heads/master
Commit: de469d235e5fe06fd3e13a32262d6c357ffdcc81
Parents: 0a8936c
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 12:50:39 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:51 2017 -0700

----------------------------------------------------------------------
 .../binary/matrix/ElementwiseAdditionMultiplicationTest.java       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/de469d23/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
index 523a648..f78e598 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix/ElementwiseAdditionMultiplicationTest.java
@@ -134,6 +134,6 @@ public class ElementwiseAdditionMultiplicationTest extends AutomatedTestBase
 
 		runTest();
 
-		compareResults();
+		compareResults(1e-10);
 	}
 }


[20/23] systemml git commit: Minor optimization inside rewrite rule.

Posted by mb...@apache.org.
Minor optimization inside rewrite rule.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/e93c487e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/e93c487e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/e93c487e

Branch: refs/heads/master
Commit: e93c487ef1778934c94fac291c6e76651041c961
Parents: d18a4c8
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 23:04:21 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 23:04:21 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 38 ++++++++------------
 1 file changed, 15 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/e93c487e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1f85bbf..41fc61d 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -26,8 +26,8 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
 import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
+import java.util.SortedSet;
+import java.util.TreeSet;
 
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
@@ -85,12 +85,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		// 1. Find immediate subtree of EMults. Check dimsKnown.
 		if (isBinaryMult(root) && root.dimsKnown()) {
 			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+			// The set of BinaryOp element-wise multiply hops in the emult chain.
 			final Set<BinaryOp> emults = new HashSet<>();
+			// The multiset of multiplicands in the emult chain.
 			final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
 			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
 			// 2. Ensure it is profitable to do a rewrite.
-			if (isOptimizable(emults, leaves)) {
+			// Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+			if (emults.size() >= 2) {
 				// 3. Check for foreign parents.
 				// A foreign parent is a parent of some EMult that is not in the set.
 				// Foreign parents destroy correctness of this rewrite.
@@ -132,20 +135,20 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 
 	private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
 		// Sort by data type
-		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+		final SortedSet<Hop> sorted = new TreeSet<>(compareByDataType);
 		for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
 			final Hop h = entry.getKey();
 			// unlink parents that are in the emult set(we are throwing them away)
 			// keep other parents
 			h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
-			sorted.put(h, entry.getValue());
+			sorted.add(constructPower(h, entry.getValue()));
 		}
 		// sorted contains all leaves, sorted by data type, stripped from their parents
 
 		// Construct right-deep EMult tree
-		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		final Iterator<Hop> iterator = sorted.iterator();
 
-		Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		Hop next = iterator.hasNext() ? iterator.next() : null;
 		Hop colVectorsScalars = null;
 		while(next != null &&
 				(next.getDataType() == Expression.DataType.SCALAR
@@ -157,7 +160,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
 				colVectorsScalars.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// next is not processed and is either null or past col vectors
 
@@ -171,7 +174,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
 				rowVectors.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// next is not processed and is either null or past row vectors
 
@@ -185,7 +188,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
 				matrices.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// next is not processed and is either null or past matrices
 
@@ -198,7 +201,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
 				other.setVisited();
 			}
-			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+			next = iterator.hasNext() ? iterator.next() : null;
 		}
 		// finished
 
@@ -230,9 +233,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		return top;
 	}
 
-	private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
-		final Hop hop = entry.getKey();
-		final int cnt = entry.getValue();
+	private static Hop constructPower(final Hop hop, final int cnt) {
 		assert(cnt >= 1);
 		hop.setVisited(); // we will visit the leaves' children next
 		if (cnt == 1)
@@ -345,13 +346,4 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		map.put(k, map.getOrDefault(k, 0) + 1);
 	}
 
-	/**
-	 * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
-	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
-	 * @param leaves The multiset of multiplicands in the emult chain.
-	 * @return If the multiset is worth optimizing.
-	 */
-	private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
-		return emults.size() >= 2;
-	}
 }


[16/23] systemml git commit: Group together types of e-wise multiply inputs. Comprehensive test.

Posted by mb...@apache.org.
Group together types of e-wise multiply inputs. Comprehensive test.

The test covers all the different kinds of objects multiplied together.
The new order of element-wise multiply chains is as follows:

<pre>
    (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
     * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
    * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
</pre>


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/999fdfbc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/999fdfbc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/999fdfbc

Branch: refs/heads/master
Commit: 999fdfbca9ebd855e031e4b812b64f1b484a33d8
Parents: 6c3e1c5
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:10:17 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:10:17 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 122 ++++++++++++++---
 ...ElementwiseMultChainOptimizationAllTest.java | 134 +++++++++++++++++++
 .../functions/misc/RewriteEMultChainOpAll.R     |  37 +++++
 .../functions/misc/RewriteEMultChainOpAll.dml   |  31 +++++
 4 files changed, 305 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9cc8fcd..486072b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -46,6 +46,14 @@ import com.google.common.collect.Multiset;
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
+ *
+ * The new order of element-wise multiply chains is as follows:
+ * <pre>
+ *     (((unknown * object * frame) * ([least-nnz-matrix * matrix] * most-nnz-matrix))
+ *      * ([least-nnz-row-vector * row-vector] * most-nnz-row-vector))
+ *     * ([[scalars * least-nnz-col-vector] * col-vector] * most-nnz-col-vector)
+ * </pre>
+ * Identical elements are replaced with powers.
  */
 public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	@Override
@@ -137,14 +145,90 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 
 		// Construct right-deep EMult tree
 		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
-		Hop first = constructPower(iterator.next());
 
-		for (int i = 1; i < sorted.size(); i++) {
-			final Hop second = constructPower(iterator.next());
-			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
-			first.setVisited();
+		Hop next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		Hop colVectorsScalars = null;
+		while(next != null &&
+				(next.getDataType() == Expression.DataType.SCALAR
+						|| next.getDataType() == Expression.DataType.MATRIX && next.getDim2() == 1))
+		{
+			if( colVectorsScalars == null )
+				colVectorsScalars = next;
+			else {
+				colVectorsScalars = HopRewriteUtils.createBinary(next, colVectorsScalars, Hop.OpOp2.MULT);
+				colVectorsScalars.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		}
+		// next is not processed and is either null or past col vectors
+
+		Hop rowVectors = null;
+		while(next != null &&
+				(next.getDataType() == Expression.DataType.MATRIX && next.getDim1() == 1))
+		{
+			if( rowVectors == null )
+				rowVectors = next;
+			else {
+				rowVectors = HopRewriteUtils.createBinary(rowVectors, next, Hop.OpOp2.MULT);
+				rowVectors.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		}
+		// next is not processed and is either null or past row vectors
+
+		Hop matrices = null;
+		while(next != null &&
+				(next.getDataType() == Expression.DataType.MATRIX))
+		{
+			if( matrices == null )
+				matrices = next;
+			else {
+				matrices = HopRewriteUtils.createBinary(matrices, next, Hop.OpOp2.MULT);
+				matrices.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
 		}
-		return first;
+		// next is not processed and is either null or past matrices
+
+		Hop other = null;
+		while(next != null)
+		{
+			if( other == null )
+				other = next;
+			else {
+				other = HopRewriteUtils.createBinary(other, next, Hop.OpOp2.MULT);
+				other.setVisited();
+			}
+			next = iterator.hasNext() ? constructPower(iterator.next()) : null;
+		}
+		// finished
+
+		// ((other * matrices) * rowVectors) * colVectorsScalars
+		Hop top = null;
+		if( other == null && matrices != null )
+			top = matrices;
+		else if( other != null && matrices == null )
+			top = other;
+		else if( other != null ) { //matrices != null
+			top = HopRewriteUtils.createBinary(other, matrices, Hop.OpOp2.MULT);
+			top.setVisited();
+		}
+
+		if( top == null && rowVectors != null )
+			top = rowVectors;
+		else if( rowVectors != null ) { //top != null
+			top = HopRewriteUtils.createBinary(top, rowVectors, Hop.OpOp2.MULT);
+			top.setVisited();
+		}
+
+		if( top == null && colVectorsScalars != null )
+			top = colVectorsScalars;
+		else if( colVectorsScalars != null ) { //top != null
+			top = HopRewriteUtils.createBinary(top, colVectorsScalars, Hop.OpOp2.MULT);
+			top.setVisited();
+		}
+
+		return top;
 	}
 
 	private static Hop constructPower(final Map.Entry<Hop, Integer> entry) {
@@ -154,7 +238,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		hop.setVisited(); // we will visit the leaves' children next
 		if (cnt == 1)
 			return hop;
-		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		final Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 		pow.setVisited();
 		return pow;
 	}
@@ -162,8 +246,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	/**
 	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
 	 * The order is as follows:
-	 * 		scalars > col vectors > row vectors >
-	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	 * 		scalars < col vectors < row vectors <
+	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) <
 	 *      other data types.
 	 * Disambiguate by Hop ID.
 	 */
@@ -172,11 +256,11 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		{
 			for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++)
 				switch(Expression.DataType.values()[i]) {
-				case SCALAR: orderDataType[i] = 4; break;
-				case MATRIX: orderDataType[i] = 3; break;
+				case SCALAR: orderDataType[i] = 0; break;
+				case MATRIX: orderDataType[i] = 1; break;
 				case FRAME:  orderDataType[i] = 2; break;
-				case OBJECT: orderDataType[i] = 1; break;
-				case UNKNOWN:orderDataType[i] = 0; break;
+				case OBJECT: orderDataType[i] = 3; break;
+				case UNKNOWN:orderDataType[i] = 4; break;
 				}
 		}
 
@@ -190,15 +274,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 			case MATRIX:
 				// two matrices; check for vectors
 				if (o1.getDim2() == 1) { // col vector
-						if (o2.getDim2() != 1) return 1; // col vectors are greatest of matrices
+						if (o2.getDim2() != 1) return -1; // col vectors are the least of matrices (they sort first)
 						return compareBySparsityThenId(o1, o2); // both col vectors
 				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
-						return -1; // col vectors are the greatest matrices
+						return 1; // col vectors are the least of matrices (they sort first)
 				} else if (o1.getDim1() == 1) { // row vector
-						if (o2.getDim1() != 1) return 1; // row vectors greater than non-vectors
+						if (o2.getDim1() != 1) return -1; // row vectors are less than non-vectors
 						return compareBySparsityThenId(o1, o2); // both row vectors
 				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
-						return 1; // col vectors greater than non-vectors
+						return 1; // row vectors are less than non-vectors
 				} else { // both non-vectors
 						return compareBySparsityThenId(o1, o2);
 				}
@@ -209,10 +293,10 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		private int compareBySparsityThenId(final Hop o1, final Hop o2) {
 			// the hop with more nnz is first; unknown nnz (-1) last
 			final int c = Long.compare(o1.getNnz(), o2.getNnz());
-			if (c != 0) return c;
+			if (c != 0) return -c;
 			return Long.compare(o1.getHopID(), o2.getHopID());
 		}
-	}.reversed();
+	};
 
 	/**
 	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.

http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
new file mode 100644
index 0000000..ba5c78d
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests that `2*X*3*v*5*w*4*z*5*Y*2*v*2*X`, where `v` and `z` are row vectors and `w` is a column vector,
+ * is successfully rewritten to `Y*(X^2)*z*(v^2)*w*2400`.
+ */
+public class RewriteElementwiseMultChainOptimizationAllTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOpAll";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationAllTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{	
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+		}
+		
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+		
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), input("v"), input("z"), input("w"), output("R") };
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+			double[][] z = getRandomMatrix(1, cols, -1, 1, Ysparsity, 5);
+			double[][] v = getRandomMatrix(1, cols, -1, 1, Xsparsity, 4);
+			double[][] w = getRandomMatrix(rows, 1, -1, 1, Ysparsity, 6);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+			writeInputMatrixWithMTD("z", z, true);
+			writeInputMatrixWithMTD("v", v, true);
+			writeInputMatrixWithMTD("w", w, true);
+
+			//execute tests
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare matrices 
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//check for presence of power operator, if we did a rewrite
+			if( rewrites ) {
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
new file mode 100644
index 0000000..20f76c2
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.R
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+# args[1]=""
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+v = as.matrix(readMM(paste(args[1], "v.mtx", sep="")))
+z = as.matrix(readMM(paste(args[1], "z.mtx", sep="")))
+w = as.matrix(readMM(paste(args[1], "w.mtx", sep="")))
+
+R = 2* X *3* X *5* Y *4*5*2*2* (matrix(1,length(w),1)%*%z) * (matrix(1,length(w),1)%*%v)^2 * (w%*%matrix(1,1,length(v)))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/999fdfbc/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
new file mode 100644
index 0000000..90f9242
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpAll.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+v = read($3);
+z = read($4);
+w = read($5);
+
+R = 2* X *3* v *5* w *4* z *5* Y *2* v *2* X
+
+write(R, $6);
\ No newline at end of file


[15/23] systemml git commit: Change order of row and col vectors, so as to create inner products rather than outer products.

Posted by mb...@apache.org.
Change order of row and col vectors, so as to create inner products rather than outer products.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/6c3e1c5b
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/6c3e1c5b
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/6c3e1c5b

Branch: refs/heads/master
Commit: 6c3e1c5bad30dc8f11ff9d3f412ce68873c37202
Parents: 04f692d
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 20:08:22 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 20:08:22 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 24 ++++++++++----------
 1 file changed, 12 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/6c3e1c5b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9ca0932..9cc8fcd 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -162,7 +162,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	/**
 	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
 	 * The order is as follows:
-	 * 		scalars > row vectors > col vectors >
+	 * 		scalars > col vectors > row vectors >
 	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
 	 *      other data types.
 	 * Disambiguate by Hop ID.
@@ -181,23 +181,23 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		}
 
 		@Override
-		public final int compare(Hop o1, Hop o2) {
-			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+		public final int compare(final Hop o1, final Hop o2) {
+			final int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
 			if (c != 0) return c;
 
 			// o1 and o2 have the same data type
 			switch (o1.getDataType()) {
 			case MATRIX:
 				// two matrices; check for vectors
-				if (o1.getDim1() == 1) { // row vector
-						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
-						return compareBySparsityThenId(o1, o2); // both row vectors
-				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
-						return -1; // row vectors are the greatest matrices
-				} else if (o1.getDim2() == 1) { // col vector
-						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+				if (o1.getDim2() == 1) { // col vector
+						if (o2.getDim2() != 1) return 1; // col vectors are greatest of matrices
 						return compareBySparsityThenId(o1, o2); // both col vectors
 				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
+						return -1; // col vectors are the greatest matrices
+				} else if (o1.getDim1() == 1) { // row vector
+						if (o2.getDim1() != 1) return 1; // row vectors greater than non-vectors
+						return compareBySparsityThenId(o1, o2); // both row vectors
+				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
 						return 1; // col vectors greater than non-vectors
 				} else { // both non-vectors
 						return compareBySparsityThenId(o1, o2);
@@ -206,9 +206,9 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				return Long.compare(o1.getHopID(), o2.getHopID());
 			}
 		}
-		private int compareBySparsityThenId(Hop o1, Hop o2) {
+		private int compareBySparsityThenId(final Hop o1, final Hop o2) {
 			// the hop with more nnz is first; unknown nnz (-1) last
-			int c = Long.compare(o1.getNnz(), o2.getNnz());
+			final int c = Long.compare(o1.getNnz(), o2.getNnz());
 			if (c != 0) return c;
 			return Long.compare(o1.getHopID(), o2.getHopID());
 		}


[02/23] systemml git commit: Fix RewriteEMult comparator. Add tests.

Posted by mb...@apache.org.
Fix RewriteEMult comparator. Add tests.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/eb0599df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/eb0599df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/eb0599df

Branch: refs/heads/master
Commit: eb0599df4c3bcca15531b85a3d870a26e4653179
Parents: 7d57883
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 11:18:32 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:15 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/OptimizerUtils.java   |   9 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   3 +-
 .../apache/sysml/hops/rewrite/RewriteEMult.java |  10 +-
 .../functions/misc/RewriteEMultChainTest.java   | 127 +++++++++
 .../ternary/ABATernaryAggregateTest.java        | 268 +++++++++++++++++++
 .../functions/misc/RewriteEMultChainOp.R        |  33 +++
 .../functions/misc/RewriteEMultChainOp.dml      |  28 ++
 .../functions/ternary/ABATernaryAggregateC.R    |  32 +++
 .../functions/ternary/ABATernaryAggregateC.dml  |  30 +++
 .../functions/ternary/ABATernaryAggregateRC.R   |  33 +++
 .../functions/ternary/ABATernaryAggregateRC.dml |  30 +++
 11 files changed, 597 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index a40e36c..2a76d07 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -110,8 +110,13 @@ public class OptimizerUtils
 	 */
 	public static boolean ALLOW_CONSTANT_FOLDING = true;
 	
-	public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true; 
-	public static boolean ALLOW_OPERATOR_FUSION = true; 
+	public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
+	/**
+	 * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in
+	 * `A*B*A ==> (A^2)*B`.
+	 */
+	public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
+	public static boolean ALLOW_OPERATOR_FUSION = true;
 	
 	/**
 	 * Enables if-else branch removal for constant predicates (original literals or 

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 8573dd7..b6aab38 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,7 +96,8 @@ public class ProgramRewriter
 			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
-			_dagRuleSet.add( new RewriteEMult()                                  ); //dependency: cse
+			if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
+				_dagRuleSet.add( new RewriteEMult()                              ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 47c32a9..2c9e5cb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -50,7 +50,6 @@ public class RewriteEMult extends HopRewriteRule {
 	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
 		if( roots == null )
 			return null;
-
 		for( int i=0; i<roots.size(); i++ ) {
 			Hop h = roots.get(i);
 			roots.set(i, rule_RewriteEMult(h));
@@ -83,6 +82,7 @@ public class RewriteEMult extends HopRewriteRule {
 			final Set<BinaryOp> emults = new HashSet<>();
 			final Multiset<Hop> leaves = HashMultiset.create();
 			findEMultsAndLeaves(r, emults, leaves);
+
 			// 2. Ensure it is profitable to do a rewrite.
 			if (isOptimizable(leaves)) {
 				// 3. Check for foreign parents.
@@ -93,8 +93,12 @@ public class RewriteEMult extends HopRewriteRule {
 				if (okay) {
 					// 4. Construct replacement EMults for the leaves
 					final Hop replacement = constructReplacement(leaves);
-
 					// 5. Replace root with replacement
+					if (LOG.isDebugEnabled())
+						LOG.debug(String.format(
+								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+								emults.size(), root.getHopID(), replacement.getHopID()));
+					replacement.setVisited();
 					return HopRewriteUtils.replaceHop(root, replacement);
 				}
 			}
@@ -141,7 +145,7 @@ public class RewriteEMult extends HopRewriteRule {
 		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 	}
 
-	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType);
+	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode);
 
 	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
 		final ArrayList<Hop> parents = child.getParent();

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
new file mode 100644
index 0000000..e076c95
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
+ */
+public class RewriteEMultChainTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOp";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{	
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+		}
+		
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+		OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+		
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{ "-explain", "hops", "-stats", 
+				"-args", input("X"), input("Y"), output("R") };
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());			
+
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+			
+			//execute tests
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare matrices 
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//check for presence of power operator, if we did a rewrite
+			if( rewrites ) {
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}	
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
new file mode 100644
index 0000000..198e9f4
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.ternary;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.apache.sysml.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)` and `colSums(A*B*A)`.
+ * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ */
+public class ABATernaryAggregateTest extends AutomatedTestBase
+{
+	private final static String TEST_NAME1 = "ABATernaryAggregateRC";
+	private final static String TEST_NAME2 = "ABATernaryAggregateC";
+	
+	private final static String TEST_DIR = "functions/ternary/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-8;
+	
+	private final static int rows = 1111;
+	private final static int cols = 1011;
+	
+	private final static double sparsity1 = 0.7;
+	private final static double sparsity2 = 0.3;
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); 
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); 
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCDenseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCDenseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCDenseVectorMR() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, true, ExecType.MR);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseVectorMR() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, true, ExecType.MR);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCDenseMatrixMR() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, true, ExecType.MR);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseMatrixMR() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, true, ExecType.MR);
+	}
+	
+	@Test
+	public void testTernaryAggregateCDenseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCSparseVectorCP() {
+		runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCDenseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCSparseMatrixCP() {
+		runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCDenseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME2, false, true, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateCSparseVectorSP() {
+		runTernaryAggregateTest(TEST_NAME2, true, true, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateCDenseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME2, false, false, true, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testTernaryAggregateCSparseMatrixSP() {
+		runTernaryAggregateTest(TEST_NAME2, true, false, true, ExecType.SPARK);
+	}
+	
+	//additional tests to check default without rewrites
+	
+	@Test
+	public void testTernaryAggregateRCDenseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, false, true, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, true, true, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, false, false, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME1, true, false, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCDenseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, false, true, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCSparseVectorCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, true, true, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCDenseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, false, false, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
+		runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP);
+	}
+	
+	
+	
+	private void runTernaryAggregateTest(String testname, boolean sparse, boolean vectors, boolean rewrites, ExecType et)
+	{
+		//select the runtime platform for the requested exec type (MR, Spark, or hybrid by default)
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+		}
+	
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+	
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
+				rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+		
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{"-explain","hops","-stats","-args", input("A"), output("R")};
+			
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + 
+				inputDir() + " " + expectedDir();
+	
+			//generate actual dataset
+			double sparsity = sparse ? sparsity2 : sparsity1;
+			double[][] A = getRandomMatrix(vectors ? rows*cols : rows, 
+					vectors ? 1 : cols, 0, 1, sparsity, 17); 
+			writeInputMatrixWithMTD("A", A, true);
+			
+			//run test cases
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare output matrices 
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//check that the expected opcode (tak+* or tack+*) is among the heavy hitters; not checked for MR
+			if( rewrites && et != ExecType.MR ) {
+				String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "") + 
+					(((testname.equals(TEST_NAME1) || vectors ) ? "tak+*" : "tack+*"));
+				Assert.assertTrue(Statistics.getCPHeavyHitterOpCodes().contains(opcode));
+			}
+		}
+		finally {
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
new file mode 100644
index 0000000..9601089
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = A * 2;
+
+R = t(as.matrix(colSums(A * B * A)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
new file mode 100644
index 0000000..78285af
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+B = A * 2;
+C = A * 3;
+
+if(1==1){}
+
+R = colSums(A * B * A);
+
+write(R, $2);

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
new file mode 100644
index 0000000..6552c7e
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B = A * 2;
+
+s = sum(A * B * A);
+R = as.matrix(s);
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/eb0599df/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
new file mode 100644
index 0000000..965c8d3
--- /dev/null
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateRC.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+B = A * 2;
+
+if(1==1){}
+
+s = sum(A * B * A);
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file


[09/23] systemml git commit: Review comments, part 1

Posted by mb...@apache.org.
Review comments, part 1


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b94557fd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b94557fd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b94557fd

Branch: refs/heads/master
Commit: b94557fd2c90c591179cdbf05a32242fadc36448
Parents: d88f867
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 00:35:52 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:37 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  21 +-
 .../org/apache/sysml/hops/OptimizerUtils.java   |   5 -
 .../sysml/hops/rewrite/HopRewriteUtils.java     |   3 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   4 +-
 .../RewriteAlgebraicSimplificationDynamic.java  |  16 +-
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 284 -------------------
 ...RewriteElementwiseMultChainOptimization.java | 281 ++++++++++++++++++
 .../functions/misc/RewriteEMultChainTest.java   | 127 ---------
 ...ementwiseMultChainOptimizationChainTest.java | 127 +++++++++
 .../ternary/ABATernaryAggregateTest.java        |   9 +-
 .../functions/misc/ZPackageSuite.java           |   1 +
 .../functions/ternary/ZPackageSuite.java        |   3 +-
 12 files changed, 436 insertions(+), 445 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 300a20c..a207831 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -497,7 +497,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				if (binput1.getOp() == OpOp2.POW
 						&& binput1.getInput().get(1) instanceof LiteralOp) {
 					LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
-					ret = lit.getLongValue() == 3;
+					ret = HopRewriteUtils.getIntValueSafe(lit) == 3;
 				}
 				else if (binput1.getOp() == OpOp2.MULT
 						// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
@@ -640,15 +640,10 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 		boolean handled = false;
 
 		if (input1.getOp() == OpOp2.POW) {
-			switch ((int)((LiteralOp)input12).getLongValue()) {
-			case 3:
-				in1 = input11.constructLops();
-				in2 = in1;
-				in3 = in1;
-				break;
-			default:
-				throw new AssertionError("unreachable; only applies to power 3");
-			}
+			assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3";
+			in1 = input11.constructLops();
+			in2 = in1;
+			in3 = in1;
 			handled = true;
 		} else if (input11 instanceof BinaryOp ) {
 			BinaryOp b11 = (BinaryOp)input11;
@@ -662,8 +657,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 			case POW: // A*A*B case
 				Hop b112 = b11.getInput().get(1);
 				if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
-						&& b112 instanceof LiteralOp
-						&& ((LiteralOp)b112).getLongValue() == 2) {
+						&& HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
 					in1 = b11.getInput().get(0).constructLops();
 					in2 = in1;
 					in3 = input12.constructLops();
@@ -682,8 +676,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				break;
 			case POW: // A*B*B case
 				Hop b112 = b12.getInput().get(1);
-				if ( b112 instanceof LiteralOp
-						&& ((LiteralOp)b112).getLongValue() == 2) {
+				if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
 					in1 = b12.getInput().get(0).constructLops();
 					in2 = in1;
 					in3 = input11.constructLops();

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index 2a76d07..79b7ee6 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -111,11 +111,6 @@ public class OptimizerUtils
 	public static boolean ALLOW_CONSTANT_FOLDING = true;
 	
 	public static boolean ALLOW_ALGEBRAIC_SIMPLIFICATION = true;
-	/**
-	 * Enables rewriting chains of element-wise multiplies that contain the same multiplicand more than once, as in
-	 * `A*B*A ==> (A^2)*B`.
-	 */
-	public static boolean ALLOW_EMULT_CHAIN_REWRITE = true;
 	public static boolean ALLOW_OPERATOR_FUSION = true;
 	
 	/**

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 17ac4ec..8f71359 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -251,8 +251,7 @@ public class HopRewriteUtils
 	 * @return replacement
 	 */
 	public static Hop replaceHop(final Hop old, final Hop replacement) {
-		final ArrayList<Hop> rootParents = old.getParent();
-		if (rootParents.isEmpty())
+		if (old.getParent().isEmpty())
 			return replacement; // new old!
 		HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
 		return replacement;

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index b6aab38..1053850 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,8 @@ public class ProgramRewriter
 			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
-			if ( OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE )
-				_dagRuleSet.add( new RewriteEMult()                              ); //dependency: cse
+			if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()                              ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 166af2f..91c5972 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -53,6 +53,8 @@ import org.apache.sysml.parser.DataExpression;
 import org.apache.sysml.parser.Expression.DataType;
 import org.apache.sysml.parser.Expression.ValueType;
 
+import static org.apache.sysml.hops.OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+
 /**
  * Rule: Algebraic Simplifications. Simplifies binary expressions
  * in terms of two major purposes: (1) rewrite binary operations
@@ -2051,12 +2053,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 					&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
 					&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
 					&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
-					&& !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW)      // do not rewrite (A^2)*B
-						&& hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp  // let tak+* handle it
-						&& ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2)
-					&& !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW)      // do not rewrite B*(A^2)
-						&& hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp  // let tak+* handle it
-						&& ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2)
+					&& ( !ALLOW_SUM_PRODUCT_REWRITES
+						|| !(  HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW)     // do not rewrite (A^2)*B
+							&& hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp   // let tak+* handle it
+							&& ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 ))
+					&& ( !ALLOW_SUM_PRODUCT_REWRITES
+						|| !( HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW)      // do not rewrite B*(A^2)
+							&& hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp   // let tak+* handle it
+							&& ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 ))
 					)
 			{
 				baLeft = hi2.getInput().get(0);

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
deleted file mode 100644
index 5cd1471..0000000
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ /dev/null
@@ -1,284 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.hops.rewrite;
-
-import java.util.ArrayList;
-import java.util.Comparator;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Set;
-import java.util.SortedMap;
-import java.util.TreeMap;
-
-import org.apache.sysml.hops.BinaryOp;
-import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.HopsException;
-import org.apache.sysml.hops.LiteralOp;
-import org.apache.sysml.parser.Expression;
-
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-
-/**
- * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
- *
- * Rewrite a chain of element-wise multiply hops that contain identical elements.
- * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
- *
- * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
- * since foreign parents may rely on the individual results.
- */
-public class RewriteEMult extends HopRewriteRule {
-	@Override
-	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
-		if( roots == null )
-			return null;
-		for( int i=0; i<roots.size(); i++ ) {
-			Hop h = roots.get(i);
-			roots.set(i, rule_RewriteEMult(h));
-		}
-		return roots;
-	}
-
-	@Override
-	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
-		if( root == null )
-			return null;
-		return rule_RewriteEMult(root);
-	}
-
-	private static boolean isBinaryMult(final Hop hop) {
-		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
-	}
-
-	private static Hop rule_RewriteEMult(final Hop root) {
-		if (root.isVisited())
-			return root;
-		root.setVisited();
-
-		// 1. Find immediate subtree of EMults.
-		if (isBinaryMult(root)) {
-			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
-			final Set<BinaryOp> emults = new HashSet<>();
-			final Multiset<Hop> leaves = HashMultiset.create();
-			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
-
-			// 2. Ensure it is profitable to do a rewrite.
-			if (isOptimizable(emults, leaves)) {
-				// 3. Check for foreign parents.
-				// A foreign parent is a parent of some EMult that is not in the set.
-				// Foreign parents destroy correctness of this rewrite.
-				final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
-						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
-				if (okay) {
-					// 4. Construct replacement EMults for the leaves
-					final Hop replacement = constructReplacement(leaves);
-					if (LOG.isDebugEnabled())
-						LOG.debug(String.format(
-								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
-								emults.size(), root.getHopID(), replacement.getHopID()));
-
-					// 5. Replace root with replacement
-					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
-
-					// 6. Recurse at leaves (no need to repeat the interior emults)
-					for (final Hop leaf : leaves.elementSet()) {
-						recurseInputs(leaf);
-					}
-					return newRoot;
-				}
-			}
-		}
-		// This rewrite is not applicable to the current root.
-		// Try the root's children.
-		recurseInputs(root);
-		return root;
-	}
-
-	private static void recurseInputs(final Hop parent) {
-		final ArrayList<Hop> inputs = parent.getInput();
-		for (int i = 0; i < inputs.size(); i++) {
-			final Hop input = inputs.get(i);
-			final Hop newInput = rule_RewriteEMult(input);
-			inputs.set(i, newInput);
-		}
-	}
-
-	private static Hop constructReplacement(final Multiset<Hop> leaves) {
-		// Sort by data type
-		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
-		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
-			final Hop h = entry.getElement();
-			// unlink parents (the EMults, which we are throwing away)
-			h.getParent().clear();
-			sorted.put(h, entry.getCount());
-		}
-		// sorted contains all leaves, sorted by data type, stripped from their parents
-
-		// Construct right-deep EMult tree
-		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
-		Hop first = constructPower(iterator.next());
-
-		for (int i = 1; i < sorted.size(); i++) {
-			final Hop second = constructPower(iterator.next());
-			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
-			first.setVisited();
-		}
-		return first;
-	}
-
-	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
-		final Hop hop = entry.getKey();
-		final int cnt = entry.getValue();
-		assert(cnt >= 1);
-		if (cnt == 1)
-			return hop; // don't set this visited... we will visit this next
-		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
-		pow.setVisited();
-		return pow;
-	}
-
-	/**
-	 * A Comparator that orders Hops by their data type, dimention, and sparsity.
-	 * The order is as follows:
-	 * 		scalars > row vectors > col vectors >
-	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
-	 *      other data types.
-	 * Disambiguate by Hop ID.
-	 */
-	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
-		@Override
-		public final int compare(Hop o1, Hop o2) {
-			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
-			if (c != 0) return c;
-
-			// o1 and o2 have the same data type
-			switch (o1.getDataType()) {
-			case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
-			case MATRIX:
-				// two matrices; check for vectors
-				if (o1.getDim1() == 1) { // row vector
-						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
-						return compareBySparsityThenId(o1, o2); // both row vectors
-				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
-						return -1; // row vectors are the greatest matrices
-				} else if (o1.getDim2() == 1) { // col vector
-						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
-						return compareBySparsityThenId(o1, o2); // both col vectors
-				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
-						return 1; // col vectors greater than non-vectors
-				} else { // both non-vectors
-						return compareBySparsityThenId(o1, o2);
-				}
-			default:
-				return Long.compare(o1.getHopID(), o2.getHopID());
-			}
-		}
-		private int compareBySparsityThenId(Hop o1, Hop o2) {
-			// the hop with more nnz is first; unknown nnz (-1) last
-			int c = Long.compare(o1.getNnz(), o2.getNnz());
-			if (c != 0) return c;
-			return Long.compare(o1.getHopID(), o2.getHopID());
-		}
-		private final int[] orderDataType;
-		{
-			Expression.DataType[] dtValues = Expression.DataType.values();
-			orderDataType = new int[dtValues.length];
-			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
-				switch(dtValues[i]) {
-				case SCALAR:
-					orderDataType[i] = 4;
-					break;
-				case MATRIX:
-					orderDataType[i] = 3;
-					break;
-				case FRAME:
-					orderDataType[i] = 2;
-					break;
-				case OBJECT:
-					orderDataType[i] = 1;
-					break;
-				case UNKNOWN:
-					orderDataType[i] = 0;
-					break;
-				}
-			}
-		}
-	};
-
-	/**
-	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
-	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
-	 * @param child An interior emult hop in the emult chain dag.
-	 * @return Whether this interior emult or any child emult has a foreign parent.
-	 */
-	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
-		final ArrayList<Hop> parents = child.getParent();
-		if (parents.size() > 1)
-			for (final Hop parent : parents)
-				//noinspection SuspiciousMethodCalls
-				if (!emults.contains(parent))
-					return false;
-		// child does not have foreign parents
-
-		final ArrayList<Hop> inputs = child.getInput();
-		final Hop left = inputs.get(0), right = inputs.get(1);
-		return  (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
-				(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
-	}
-
-	/**
-	 * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
-	 * @param root Root of sub-dag
-	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
-	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
-	 */
-	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
-		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
-		emults.add(root);
-
-		final ArrayList<Hop> inputs = root.getInput();
-		final Hop left = inputs.get(0), right = inputs.get(1);
-
-		if (isBinaryMult(left)) findEMultsAndLeaves((BinaryOp) left, emults, leaves);
-		else leaves.add(left);
-
-		if (isBinaryMult(right)) findEMultsAndLeaves((BinaryOp) right, emults, leaves);
-		else leaves.add(right);
-	}
-
-	/**
-	 * Only optimize a subtree of emults if there are at least two emults.
-	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
-	 * @param leaves The multiset of multiplicands in the emult chain.
-	 * @return If the multiset is worth optimizing.
-	 */
-	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
-		// Old criterion: there should be at least one repeated leaf
-//		for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
-//			if (hopEntry.getCount() > 1)
-//				return true;
-//		}
-//		return false;
-		return emults.size() >= 2;
-	}
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
new file mode 100644
index 0000000..bd873ff
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -0,0 +1,281 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.sysml.hops.BinaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+
+/**
+ * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
+ *
+ * Rewrite a chain of element-wise multiply hops that contain identical elements.
+ * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
+ *
+ * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
+ * since foreign parents may rely on the individual results.
+ */
+public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
+	@Override
+	public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) throws HopsException {
+		if( roots == null )
+			return null;
+		for( int i=0; i<roots.size(); i++ ) {
+			Hop h = roots.get(i);
+			roots.set(i, rule_RewriteEMult(h));
+		}
+		return roots;
+	}
+
+	@Override
+	public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) throws HopsException {
+		if( root == null )
+			return null;
+		return rule_RewriteEMult(root);
+	}
+
+	private static boolean isBinaryMult(final Hop hop) {
+		return hop instanceof BinaryOp && ((BinaryOp)hop).getOp() == Hop.OpOp2.MULT;
+	}
+
+	private static Hop rule_RewriteEMult(final Hop root) {
+		if (root.isVisited())
+			return root;
+		root.setVisited();
+
+		// 1. Find immediate subtree of EMults.
+		if (isBinaryMult(root)) {
+			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
+			final Set<BinaryOp> emults = new HashSet<>();
+			final Multiset<Hop> leaves = HashMultiset.create();
+			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
+
+			// 2. Ensure it is profitable to do a rewrite.
+			if (isOptimizable(emults, leaves)) {
+				// 3. Check for foreign parents.
+				// A foreign parent is a parent of some EMult that is not in the set.
+				// Foreign parents destroy correctness of this rewrite.
+				final boolean okay = (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+						(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+				if (okay) {
+					// 4. Construct replacement EMults for the leaves
+					final Hop replacement = constructReplacement(leaves);
+					if (LOG.isDebugEnabled())
+						LOG.debug(String.format(
+								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
+								emults.size(), root.getHopID(), replacement.getHopID()));
+
+					// 5. Replace root with replacement
+					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+					// 6. Recurse at leaves (no need to repeat the interior emults)
+					for (final Hop leaf : leaves.elementSet()) {
+						recurseInputs(leaf);
+					}
+					return newRoot;
+				}
+			}
+		}
+		// This rewrite is not applicable to the current root.
+		// Try the root's children.
+		recurseInputs(root);
+		return root;
+	}
+
+	private static void recurseInputs(final Hop parent) {
+		final ArrayList<Hop> inputs = parent.getInput();
+		for (int i = 0; i < inputs.size(); i++) {
+			final Hop input = inputs.get(i);
+			final Hop newInput = rule_RewriteEMult(input);
+			inputs.set(i, newInput);
+		}
+	}
+
+	private static Hop constructReplacement(final Multiset<Hop> leaves) {
+		// Order the distinct leaves with compareByDataType; each map value is that leaf's multiplicity
+		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
+		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
+			final Hop h = entry.getElement();
+			// unlink ALL parents of the leaf; NOTE(review): this also drops parents outside the emult chain - confirm leaves are never shared
+			h.getParent().clear();
+			sorted.put(h, entry.getCount());
+		}
+		// sorted contains all leaves, sorted by data type, stripped from their parents
+
+		// Construct right-deep EMult tree: each later entry is passed as the first operand of a new MULT over the tree so far
+		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		Hop first = constructPower(iterator.next());
+
+		for (int i = 1; i < sorted.size(); i++) {
+			final Hop second = constructPower(iterator.next());
+			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+			first.setVisited();
+		}
+		return first;
+	}
+
+	private static Hop constructPower(Map.Entry<Hop, Integer> entry) {
+		final Hop hop = entry.getKey();
+		final int cnt = entry.getValue();
+		assert(cnt >= 1);
+		if (cnt == 1)
+			return hop; // don't set this visited... we will visit this next
+		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		pow.setVisited();
+		return pow;
+	}
+
+	/**
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
+	 * From greatest to least, the order is:
+	 * 		scalars > row vectors > col vectors >
+	 *      non-vector matrices > other data types.
+	 * Matrices of one class compare by ascending getNnz(); remaining ties by Hop ID.
+	 * The result is antisymmetric: compare(a, b) and compare(b, a) have opposite signs.
+	 */
+	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		@Override
+		public final int compare(Hop o1, Hop o2) {
+			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+			if (c != 0) return c;
+
+			// o1 and o2 have the same data type
+			switch (o1.getDataType()) {
+			case MATRIX:
+				// two matrices; check for vectors (each branch must mirror its counterpart's sign)
+				if (o1.getDim1() == 1) { // row vector
+						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+						return compareBySparsityThenId(o1, o2); // both row vectors
+				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+						return -1; // row vectors are the greatest matrices
+				} else if (o1.getDim2() == 1) { // col vector
+						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+						return compareBySparsityThenId(o1, o2); // both col vectors
+				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
+						return -1; // o1 is the lesser one: col vectors are greater than non-vectors
+				} else { // both non-vectors
+						return compareBySparsityThenId(o1, o2);
+				}
+			default:
+				return Long.compare(o1.getHopID(), o2.getHopID());
+			}
+		}
+		private int compareBySparsityThenId(Hop o1, Hop o2) {
+			// ascending nnz: the hop with fewer nnz compares lower, so unknown nnz (-1) compares lowest
+			int c = Long.compare(o1.getNnz(), o2.getNnz());
+			if (c != 0) return c;
+			return Long.compare(o1.getHopID(), o2.getHopID());
+		}
+		private final int[] orderDataType;
+		{
+			Expression.DataType[] dtValues = Expression.DataType.values();
+			orderDataType = new int[dtValues.length];
+			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+				switch(dtValues[i]) {
+				case SCALAR:
+					orderDataType[i] = 4;
+					break;
+				case MATRIX:
+					orderDataType[i] = 3;
+					break;
+				case FRAME:
+					orderDataType[i] = 2;
+					break;
+				case OBJECT:
+					orderDataType[i] = 1;
+					break;
+				case UNKNOWN:
+					orderDataType[i] = 0;
+					break;
+				}
+			}
+		}
+	};
+
+	/**
+	 * Check that no emult in the chain has a parent outside the set of emults. Recursively checks inputs that are also emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param child An interior emult hop in the emult chain dag.
+	 * @return False if this emult or any emult below it has more than one parent and one of them is not in emults; true otherwise.
+	 */
+	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
+		final ArrayList<Hop> parents = child.getParent();
+		if (parents.size() > 1)
+			for (final Hop parent : parents)
+				//noinspection SuspiciousMethodCalls (IntelliJ: parent is typed Hop while the set holds BinaryOp)
+				if (!emults.contains(parent))
+					return false;
+		// child does not have foreign parents
+
+		final ArrayList<Hop> inputs = child.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+		return  (!isBinaryMult(left) || checkForeignParent(emults, (BinaryOp)left)) &&
+				(!isBinaryMult(right) || checkForeignParent(emults, (BinaryOp)right));
+	}
+
+	/**
+	 * Collect all BinaryOp MULTs reachable from root through MULT inputs only, plus the non-MULT inputs (leaves) with their counts.
+	 * @param root Root of sub-dag
+	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
+	 */
+	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
+		emults.add(root);
+
+		final ArrayList<Hop> inputs = root.getInput();
+		final Hop left = inputs.get(0), right = inputs.get(1);
+
+		if (isBinaryMult(left))
+			findEMultsAndLeaves((BinaryOp) left, emults, leaves);
+		else
+			leaves.add(left);
+
+		if (isBinaryMult(right))
+			findEMultsAndLeaves((BinaryOp) right, emults, leaves);
+		else
+			leaves.add(right);
+	}
+
+	/**
+	 * Only optimize a subtree of emults if there are at least two emults (meaning, at least 3 multiplicands).
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param leaves The multiset of multiplicands in the emult chain.
+	 * @return If the multiset is worth optimizing.
+	 */
+	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		return emults.size() >= 2;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
deleted file mode 100644
index 85dbea4..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
- */
-public class RewriteEMultChainTest extends AutomatedTestBase
-{
-	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
-	private static final String TEST_DIR = "functions/misc/";
-	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
-	
-	private static final int rows = 123;
-	private static final int cols = 321;
-	private static final double eps = Math.pow(10, -10);
-	
-	@Override
-	public void setUp() {
-		TestUtils.clearAssertionInformation();
-		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
-	}
-
-	@Test
-	public void testMatrixMultChainOptNoRewritesCP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptNoRewritesSP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptRewritesCP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptRewritesSP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
-	}
-
-	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
-	{	
-		RUNTIME_PLATFORM platformOld = rtplatform;
-		switch( et ){
-			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
-			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
-		}
-		
-		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-		if( rtplatform == RUNTIME_PLATFORM.SPARK )
-			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-		
-		boolean rewritesOld = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
-		OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
-		
-		try
-		{
-			TestConfiguration config = getTestConfiguration(testname);
-			loadTestConfiguration(config);
-			
-			String HOME = SCRIPT_DIR + TEST_DIR;
-			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
-			fullRScriptName = HOME + testname + ".R";
-			rCmd = getRCmd(inputDir(), expectedDir());			
-
-			double Xsparsity = 0.8, Ysparsity = 0.6;
-			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
-			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
-			writeInputMatrixWithMTD("X", X, true);
-			writeInputMatrixWithMTD("Y", Y, true);
-
-			//execute tests
-			runTest(true, false, null, -1); 
-			runRScript(true); 
-			
-			//compare matrices 
-			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
-			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
-			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-			
-			//check for presence of power operator, if we did a rewrite
-			if( rewrites ) {
-				Assert.assertTrue(heavyHittersContainsSubString("^2"));
-			}
-		}
-		finally {
-			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOld;
-			rtplatform = platformOld;
-			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
new file mode 100644
index 0000000..47b2f0e
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
+ */
+public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{	
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+		}
+		
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+		
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());			
+
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+
+			//execute tests
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare matrices 
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//check for presence of power operator, if we did a rewrite
+			if( rewrites ) {
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 1829bf0..460829d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.rewrite.RewriteElementwiseMultChainOptimization;
 import org.apache.sysml.lops.LopProperties.ExecType;
 import org.apache.sysml.runtime.instructions.Instruction;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
@@ -36,7 +37,7 @@ import org.junit.Test;
 
 /**
  * Similar to {@link TernaryAggregateTest} except that it tests `sum(A*B*A)`.
- * Checks compatibility with {@link org.apache.sysml.hops.rewrite.RewriteEMult}.
+ * Checks compatibility with {@link RewriteElementwiseMultChainOptimization}.
  */
 public class ABATernaryAggregateTest extends AutomatedTestBase
 {
@@ -368,14 +369,14 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 	
 		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
-				rewritesOldEmult = OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE;
+				rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
 		
 		try {
 			TestConfiguration config = getTestConfiguration(testname);
 			loadTestConfiguration(config);
 			
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewrites;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
 
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
@@ -411,7 +412,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-			OptimizerUtils.ALLOW_EMULT_CHAIN_REWRITE = rewritesOldEmult;
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult;
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index e352e6d..deea784 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,6 +50,7 @@ import org.junit.runners.Suite;
 	ReadAfterWriteTest.class,
 	RewriteCSETransposeScalarTest.class,
 	RewriteCTableToRExpandTest.class,
+	RewriteElementwiseMultChainOptimizationChainTest.class,
 	RewriteEliminateAggregatesTest.class,
 	RewriteFuseBinaryOpChainTest.class,
 	RewriteFusedRandTest.class,

http://git-wip-us.apache.org/repos/asf/systemml/blob/b94557fd/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
index 784177d..ee14359 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/ternary/ZPackageSuite.java
@@ -22,10 +22,11 @@ package org.apache.sysml.test.integration.functions.ternary;
 import org.junit.runner.RunWith;
 import org.junit.runners.Suite;
 
-/** Group together the tests in this package into a single suite so that the Maven build
+/* Group together the tests in this package into a single suite so that the Maven build
  *  won't run two of them at once. */
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
+	ABATernaryAggregateTest.class,
 	CentralMomentWeightsTest.class,
 	CovarianceWeightsTest.class,
 	CTableMatrixIgnoreZerosTest.class,


[05/23] systemml git commit: TernaryAggregate now applies to a power of 3.

Posted by mb...@apache.org.
The TernaryAggregate rewrite now also applies to an element-wise power of 3 (e.g. sum(A^3) is handled like sum(A*A*A)).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f005d949
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f005d949
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f005d949

Branch: refs/heads/master
Commit: f005d94997d9c17ad8e90b4d2bd340f81b9a752d
Parents: 8b832f6
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 22:06:10 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:24 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  | 67 ++++++++++++--------
 .../functions/misc/RewriteEMultChainTest.java   |  7 +-
 .../functions/misc/RewriteEMultChainOp.R        | 33 ----------
 .../functions/misc/RewriteEMultChainOp.dml      | 28 --------
 .../functions/misc/RewriteEMultChainOpXYX.R     | 33 ++++++++++
 .../functions/misc/RewriteEMultChainOpXYX.dml   | 28 ++++++++
 6 files changed, 106 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 4573b66..300a20c 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -490,29 +490,35 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 			(_direction == Direction.RowCol || _direction == Direction.Col)  ) 
 		{
 			Hop input1 = getInput().get(0);
-			if( input1.getParent().size() == 1 && //sum single consumer
-				input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT
-				// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
-				&& input1.optFindExecType() != ExecType.MR) 
-			{
-				Hop input11 = input1.getInput().get(0);
-				Hop input12 = input1.getInput().get(1);
-				
-				if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) {
-					//ternary, arbitrary matrices but no mv/outer operations.
-					ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
-						&& HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)	
-						&& HopRewriteUtils.isEqualSize(input12, input1);
-				}
-				else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) {
-					//ternary, arbitrary matrices but no mv/outer operations.
-					ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
-							&& HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)	
-							&& HopRewriteUtils.isEqualSize(input11, input1);
+			if (input1.getParent().size() == 1
+					&& input1 instanceof BinaryOp) { //sum single consumer
+				BinaryOp binput1 = (BinaryOp)input1;
+
+				if (binput1.getOp() == OpOp2.POW
+						&& binput1.getInput().get(1) instanceof LiteralOp) {
+					LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
+					ret = lit.getLongValue() == 3;
 				}
-				else {
-					//binary, arbitrary matrices but no mv/outer operations.
-					ret = HopRewriteUtils.isEqualSize(input11, input12);
+				else if (binput1.getOp() == OpOp2.MULT
+						// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
+						&& input1.optFindExecType() != ExecType.MR) {
+					Hop input11 = input1.getInput().get(0);
+					Hop input12 = input1.getInput().get(1);
+
+					if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) {
+						//ternary, arbitrary matrices but no mv/outer operations.
+						ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils
+								.isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils
+								.isEqualSize(input12, input1);
+					} else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) {
+						//ternary, arbitrary matrices but no mv/outer operations.
+						ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils
+								.isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils
+								.isEqualSize(input11, input1);
+					} else {
+						//binary, arbitrary matrices but no mv/outer operations.
+						ret = HopRewriteUtils.isEqualSize(input11, input12);
+					}
 				}
 			}
 		}
@@ -626,14 +632,25 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 	private Lop constructLopsTernaryAggregateRewrite(ExecType et) 
 		throws HopsException, LopsException
 	{
-		Hop input1 = getInput().get(0);
+		BinaryOp input1 = (BinaryOp)getInput().get(0);
 		Hop input11 = input1.getInput().get(0);
 		Hop input12 = input1.getInput().get(1);
 		
 		Lop in1 = null, in2 = null, in3 = null;
 		boolean handled = false;
-		
-		if( input11 instanceof BinaryOp ) {
+
+		if (input1.getOp() == OpOp2.POW) {
+			switch ((int)((LiteralOp)input12).getLongValue()) {
+			case 3:
+				in1 = input11.constructLops();
+				in2 = in1;
+				in3 = in1;
+				break;
+			default:
+				throw new AssertionError("unreachable; only applies to power 3");
+			}
+			handled = true;
+		} else if (input11 instanceof BinaryOp ) {
 			BinaryOp b11 = (BinaryOp)input11;
 			switch (b11.getOp()) {
 			case MULT: // A*B*C case

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index 18ed55d..85dbea4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -37,7 +37,7 @@ import org.junit.Test;
  */
 public class RewriteEMultChainTest extends AutomatedTestBase
 {
-	private static final String TEST_NAME1 = "RewriteEMultChainOp";
+	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
 	private static final String TEST_DIR = "functions/misc/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
 	
@@ -94,8 +94,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[]{ "-explain", "hops", "-stats", 
-				"-args", input("X"), input("Y"), output("R") };
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
 			fullRScriptName = HOME + testname + ".R";
 			rCmd = getRCmd(inputDir(), expectedDir());			
 
@@ -104,7 +103,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
 			writeInputMatrixWithMTD("X", X, true);
 			writeInputMatrixWithMTD("Y", Y, true);
-			
+
 			//execute tests
 			runTest(true, false, null, -1); 
 			runRScript(true); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
deleted file mode 100644
index 6d94cc8..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.R
+++ /dev/null
@@ -1,33 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-args <- commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-library("matrixStats")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
-Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
-
-R = X * Y * X;
-
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
deleted file mode 100644
index 3992403..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
+++ /dev/null
@@ -1,28 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-X = read($1);
-Y = read($2);
-
-R = X * Y * X;
-
-write(R, $3);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file


[06/23] systemml git commit: Add tests for sum(A*A*A).

Posted by mb...@apache.org.
Add tests for sum(A*A*A).

Reduced the number of rows and columns (from 1111 x 1011 to 111 x 101) in order to decrease test time.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/edbac3b6
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/edbac3b6
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/edbac3b6

Branch: refs/heads/master
Commit: edbac3b6c4361799da32cef89c8ee4e29e187c9d
Parents: f005d94
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 22:23:02 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:26 2017 -0700

----------------------------------------------------------------------
 .../ternary/ABATernaryAggregateTest.java        | 159 ++++++++++++++++++-
 .../functions/ternary/AAATernaryAggregateC.R    |  31 ++++
 .../functions/ternary/AAATernaryAggregateC.dml  |  28 ++++
 .../functions/ternary/AAATernaryAggregateRC.R   |  32 ++++
 .../functions/ternary/AAATernaryAggregateRC.dml |  29 ++++
 .../functions/ternary/ABATernaryAggregateC.dml  |   1 -
 6 files changed, 274 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 198e9f4..1829bf0 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -42,13 +42,15 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 {
 	private final static String TEST_NAME1 = "ABATernaryAggregateRC";
 	private final static String TEST_NAME2 = "ABATernaryAggregateC";
+	private final static String TEST_NAME3 = "AAATernaryAggregateRC";
+	private final static String TEST_NAME4 = "AAATernaryAggregateC";
 	
 	private final static String TEST_DIR = "functions/ternary/";
 	private final static String TEST_CLASS_DIR = TEST_DIR + ABATernaryAggregateTest.class.getSimpleName() + "/";
 	private final static double eps = 1e-8;
 	
-	private final static int rows = 1111;
-	private final static int cols = 1011;
+	private final static int rows = 111;
+	private final static int cols = 101;
 	
 	private final static double sparsity1 = 0.7;
 	private final static double sparsity2 = 0.3;
@@ -57,7 +59,9 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 	public void setUp() {
 		TestUtils.clearAssertionInformation();
 		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); 
-		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); 
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
+		addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
 	}
 
 	@Test
@@ -201,6 +205,151 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 	public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
 		runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP);
 	}
+
+
+	// another set of tests for the case of sum(A*A*A)
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixMR_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, true, ExecType.MR);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixCP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, true, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, false, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixSP_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, false, true, ExecType.SPARK);
+	}
+
+	//additional tests to check default without rewrites
+
+	@Test
+	public void testTernaryAggregateRCDenseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCDenseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, false, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateRCSparseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME3, true, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseVectorCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCDenseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, false, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testTernaryAggregateCSparseMatrixCPNoRewrite_AAA() {
+		runTernaryAggregateTest(TEST_NAME4, true, false, false, ExecType.CP);
+	}
 	
 	
 	
@@ -230,7 +379,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[]{"-explain","hops","-stats","-args", input("A"), output("R")};
+			programArgs = new String[]{"-explain","-stats","-args", input("A"), output("R")};
 			
 			fullRScriptName = HOME + testname + ".R";
 			rCmd = "Rscript" + " " + fullRScriptName + " " + 
@@ -254,7 +403,7 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 			//check for rewritten patterns in statistics output
 			if( rewrites && et != ExecType.MR ) {
 				String opcode = ((et == ExecType.SPARK) ? Instruction.SP_INST_PREFIX : "") + 
-					(((testname.equals(TEST_NAME1) || vectors ) ? "tak+*" : "tack+*"));
+					(((testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3) || vectors ) ? "tak+*" : "tack+*"));
 				Assert.assertTrue(Statistics.getCPHeavyHitterOpCodes().contains(opcode));
 			}
 		}

http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateC.R b/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
new file mode 100644
index 0000000..a096c2b
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateC.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+R = t(as.matrix(colSums(A * A * A)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml b/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
new file mode 100644
index 0000000..b576a4d
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateC.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+if(1==1){}
+
+R = colSums(A * A * A);
+
+write(R, $2);

http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
new file mode 100644
index 0000000..776ddd0
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+s = sum(A * A * A);
+R = as.matrix(s);
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
new file mode 100644
index 0000000..7283703
--- /dev/null
+++ b/src/test/scripts/functions/ternary/AAATernaryAggregateRC.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+if(1==1){}
+
+s = sum(A * A * A);
+R = as.matrix(s);
+
+write(R, $2);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/edbac3b6/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
index 78285af..737b409 100644
--- a/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
+++ b/src/test/scripts/functions/ternary/ABATernaryAggregateC.dml
@@ -21,7 +21,6 @@
 
 A = read($1);
 B = A * 2;
-C = A * 3;
 
 if(1==1){}
 


[21/23] systemml git commit: Add new `wumm` pattern to pick up element-wise multiply rewrite.

Posted by mb...@apache.org.
Add new `wumm` pattern to pick up element-wise multiply rewrite.

The new pattern recognizes when there is a `*2` or `2*` outside `W*(U%*%t(V))`.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/479b9da4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/479b9da4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/479b9da4

Branch: refs/heads/master
Commit: 479b9da4e6c605871a914ccb4b06ab6da5de21ed
Parents: e93c487
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jul 13 01:14:48 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Thu Jul 13 01:14:48 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |  2 +-
 .../RewriteAlgebraicSimplificationDynamic.java  | 44 +++++++++++++++++++-
 2 files changed, 43 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 59565df..7c4f861 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -54,7 +54,7 @@ public class ProgramRewriter
 	private static final Log LOG = LogFactory.getLog(ProgramRewriter.class.getName());
 	
 	//internal local debug level
-	private static final boolean LDEBUG = false; 
+	private static final boolean LDEBUG = false;
 	private static final boolean CHECK = false;
 	
 	private ArrayList<HopRewriteRule> _dagRuleSet = null;

http://git-wip-us.apache.org/repos/asf/systemml/blob/479b9da4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 8cd71f4..6246270 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -29,11 +29,11 @@ import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.DataGenOp;
 import org.apache.sysml.hops.Hop;
-import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.DataGenMethod;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
 import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.OpOp4;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
@@ -44,7 +44,7 @@ import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.hops.ParameterizedBuiltinOp;
-import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.QuaternaryOp;
 import org.apache.sysml.hops.ReorgOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
@@ -1959,6 +1959,46 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			appliedPattern = true;
 			LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");	
 		}
+
+		//Pattern 1.5) (W*(U%*%t(V)))*2 or 2*(W*(U%*%t(V)))
+		if( !appliedPattern
+				&& hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
+				&& (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
+					|| HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+		{
+			final Hop nl; // non-literal
+			if( hi.getInput().get(0) instanceof LiteralOp ) {
+				nl = hi.getInput().get(1);
+			} else {
+				nl = hi.getInput().get(0);
+			}
+
+			if (       HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+					&& HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
+					&& nl.getDim2() > 1 //not applied for vector-vector mult
+					&& nl.getInput().get(0).getDataType() == DataType.MATRIX
+					&& nl.getInput().get(0).getDim2() > nl.getInput().get(0).getColsInBlock()
+					&& HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
+					&& (((AggBinaryOp) nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE || nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
+					&& HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+			{
+				final Hop W = nl.getInput().get(0);
+				final Hop U = nl.getInput().get(1).getInput().get(0);
+				Hop V = nl.getInput().get(1).getInput().get(1);
+				if( !HopRewriteUtils.isTransposeOperation(V) )
+					V = HopRewriteUtils.createTranspose(V);
+				else
+					V = V.getInput().get(0);
+
+				hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX, ValueType.DOUBLE,
+						OpOp4.WUMM, W, U, V, true, null, OpOp2.MULT);
+				hnew.setOutputBlocksizes(W.getRowsInBlock(), W.getColsInBlock());
+				hnew.refreshSizeInformation();
+
+				appliedPattern = true;
+				LOG.debug("Applied simplifyWeightedUnaryMM2.7 (line "+hi.getBeginLine()+")");
+			}
+		}
 		
 		//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to unary ops
 		if( !appliedPattern


[19/23] systemml git commit: Move to dynamic rewrites. Do not rewrite if top-level dims unknown.

Posted by mb...@apache.org.
Move to dynamic rewrites. Do not rewrite if top-level dims unknown.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d18a4c80
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d18a4c80
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d18a4c80

Branch: refs/heads/master
Commit: d18a4c80dece566ddbad34a7f3c2f70ce544023e
Parents: c4e9228
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:54:31 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:54:31 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/rewrite/ProgramRewriter.java    | 6 +++---
 .../hops/rewrite/RewriteElementwiseMultChainOptimization.java  | 5 +++--
 2 files changed, 6 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d18a4c80/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index a1ff5bc..59565df 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,6 @@ public class ProgramRewriter
 			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
-			if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
-				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()   ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -125,7 +123,9 @@ public class ProgramRewriter
 		// DYNAMIC REWRITES (which do require size information)
 		if( dynamicRewrites )
 		{
-			_dagRuleSet.add(     new RewriteMatrixMultChainOptimization()         ); //dependency: cse 
+			_dagRuleSet.add(     new RewriteMatrixMultChainOptimization()         ); //dependency: cse
+			if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()    ); //dependency: cse
 			
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
 			{

http://git-wip-us.apache.org/repos/asf/systemml/blob/d18a4c80/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index de1def8..1f85bbf 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -44,6 +44,7 @@ import org.apache.sysml.parser.Expression;
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
+ * Does not perform rewrites on an element-wise multiply if its dimensions are unknown.
  *
  * The new order of element-wise multiply chains is as follows:
  * <pre>
@@ -81,8 +82,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 			return root;
 		root.setVisited();
 
-		// 1. Find immediate subtree of EMults.
-		if (isBinaryMult(root)) {
+		// 1. Find immediate subtree of EMults. Check dimsKnown.
+		if (isBinaryMult(root) && root.dimsKnown()) {
 			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
 			final Set<BinaryOp> emults = new HashSet<>();
 			final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset


[18/23] systemml git commit: Get rid of leftover Guava dependency

Posted by mb...@apache.org.
Get rid of leftover Guava dependency


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c4e9228e
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c4e9228e
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c4e9228e

Branch: refs/heads/master
Commit: c4e9228ed0b86789f8b41a533bf112d681a30318
Parents: 3c4d777
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:46:33 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:46:33 2017 -0700

----------------------------------------------------------------------
 ...RewriteElementwiseMultChainOptimization.java | 32 +++++++++++---------
 1 file changed, 17 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c4e9228e/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 486072b..de1def8 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -21,6 +21,7 @@ package org.apache.sysml.hops.rewrite;
 
 import java.util.ArrayList;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.Map;
@@ -34,15 +35,12 @@ import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
 import org.apache.sysml.parser.Expression;
 
-import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Multiset;
-
 /**
  * Prerequisite: RewriteCommonSubexpressionElimination must run before this rule.
  *
  * Rewrite a chain of element-wise multiply hops that contain identical elements.
  * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
@@ -87,7 +85,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		if (isBinaryMult(root)) {
 			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
 			final Set<BinaryOp> emults = new HashSet<>();
-			final Multiset<Hop> leaves = HashMultiset.create();
+			final Map<Hop, Integer> leaves = new HashMap<>(); // poor man's HashMultiset
 			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
 			// 2. Ensure it is profitable to do a rewrite.
@@ -109,7 +107,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 					final Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement);
 
 					// 6. Recurse at leaves (no need to repeat the interior emults)
-					for (final Hop leaf : leaves.elementSet()) {
+					for (final Hop leaf : leaves.keySet()) {
 						recurseInputs(leaf);
 					}
 					return newRoot;
@@ -131,15 +129,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		}
 	}
 
-	private static Hop constructReplacement(final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+	private static Hop constructReplacement(final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
 		// Sort by data type
 		final SortedMap<Hop,Integer> sorted = new TreeMap<>(compareByDataType);
-		for (final Multiset.Entry<Hop> entry : leaves.entrySet()) {
-			final Hop h = entry.getElement();
+		for (final Map.Entry<Hop, Integer> entry : leaves.entrySet()) {
+			final Hop h = entry.getKey();
 			// unlink parents that are in the emult set(we are throwing them away)
 			// keep other parents
 			h.getParent().removeIf(parent -> parent instanceof BinaryOp && emults.contains(parent));
-			sorted.put(h, entry.getCount());
+			sorted.put(h, entry.getValue());
 		}
 		// sorted contains all leaves, sorted by data type, stripped from their parents
 
@@ -244,7 +242,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	}
 
 	/**
-	 * A Comparator that orders Hops by their data type, dimention, and sparsity.
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
 	 * The order is as follows:
 	 * 		scalars < col vectors < row vectors <
 	 *      non-vector matrices ordered by sparsity (higher nnz last, unknown sparsity last) >
@@ -324,7 +322,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
 	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
 	 */
-	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
 		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
 		emults.add(root);
 
@@ -334,12 +332,16 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		if (isBinaryMult(left))
 			findEMultsAndLeaves((BinaryOp) left, emults, leaves);
 		else
-			leaves.add(left);
+			addMultiset(leaves, left);
 
 		if (isBinaryMult(right))
 			findEMultsAndLeaves((BinaryOp) right, emults, leaves);
 		else
-			leaves.add(right);
+			addMultiset(leaves, right);
+	}
+
+	private static <K> void addMultiset(final Map<K,Integer> map, final K k) {
+		map.put(k, map.getOrDefault(k, 0) + 1);
 	}
 
 	/**
@@ -348,7 +350,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	 * @param leaves The multiset of multiplicands in the emult chain.
 	 * @return If the multiset is worth optimizing.
 	 */
-	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+	private static boolean isOptimizable(Set<BinaryOp> emults, final Map<Hop, Integer> leaves) {
 		return emults.size() >= 2;
 	}
 }


[03/23] systemml git commit: Add name to sorting of EMult rewrite. Handle Ternary A*A*B case.

Posted by mb...@apache.org.
Add name to sorting of EMult rewrite. Handle Ternary A*A*B case.

AggUnaryOp now constructs the TernaryOperator (A,A,B) instead of (A^2,B,1).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ff8c836c
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ff8c836c
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ff8c836c

Branch: refs/heads/master
Commit: ff8c836c7b736dbd7b7651ac792a6d8c23989c98
Parents: eb0599d
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 13:18:19 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:18 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  | 62 ++++++++++++++------
 .../apache/sysml/hops/rewrite/RewriteEMult.java |  4 +-
 2 files changed, 48 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index ee4ded2..4573b66 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -516,7 +516,6 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 				}
 			}
 		}
-		
 		return ret;
 	}
 	
@@ -631,24 +630,53 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 		Hop input11 = input1.getInput().get(0);
 		Hop input12 = input1.getInput().get(1);
 		
-		Lop in1 = null;
-		Lop in2 = null;
-		Lop in3 = null;
+		Lop in1 = null, in2 = null, in3 = null;
+		boolean handled = false;
 		
-		if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT )
-		{
-			in1 = input11.getInput().get(0).constructLops();
-			in2 = input11.getInput().get(1).constructLops();
-			in3 = input12.constructLops();
-		}
-		else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT )
-		{
-			in1 = input11.constructLops();
-			in2 = input12.getInput().get(0).constructLops();
-			in3 = input12.getInput().get(1).constructLops();
+		if( input11 instanceof BinaryOp ) {
+			BinaryOp b11 = (BinaryOp)input11;
+			switch (b11.getOp()) {
+			case MULT: // A*B*C case
+				in1 = input11.getInput().get(0).constructLops();
+				in2 = input11.getInput().get(1).constructLops();
+				in3 = input12.constructLops();
+				handled = true;
+				break;
+			case POW: // A*A*B case
+				Hop b112 = b11.getInput().get(1);
+				if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT)
+						&& b112 instanceof LiteralOp
+						&& ((LiteralOp)b112).getLongValue() == 2) {
+					in1 = b11.getInput().get(0).constructLops();
+					in2 = in1;
+					in3 = input12.constructLops();
+					handled = true;
+				}
+				break;
+			}
+		} else if( input12 instanceof BinaryOp ) {
+			BinaryOp b12 = (BinaryOp)input12;
+			switch (b12.getOp()) {
+			case MULT: // A*B*C case
+				in1 = input11.constructLops();
+				in2 = input12.getInput().get(0).constructLops();
+				in3 = input12.getInput().get(1).constructLops();
+				handled = true;
+				break;
+			case POW: // A*B*B case
+				Hop b112 = b12.getInput().get(1);
+				if ( b112 instanceof LiteralOp
+						&& ((LiteralOp)b112).getLongValue() == 2) {
+					in1 = b12.getInput().get(0).constructLops();
+					in2 = in1;
+					in3 = input11.constructLops();
+					handled = true;
+				}
+				break;
+			}
 		}
-		else
-		{
+
+		if (!handled) {
 			in1 = input11.constructLops();
 			in2 = input12.constructLops();
 			in3 = new LiteralOp(1).constructLops();

http://git-wip-us.apache.org/repos/asf/systemml/blob/ff8c836c/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 2c9e5cb..66da6fa 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -145,7 +145,9 @@ public class RewriteEMult extends HopRewriteRule {
 		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 	}
 
-	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType).thenComparing(Object::hashCode);
+	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType)
+			.thenComparing(Hop::getName)
+			.thenComparingInt(Object::hashCode);
 
 	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
 		final ArrayList<Hop> parents = child.getParent();


[23/23] systemml git commit: [SYSTEMML-1663] Fix and enable rewrite element-wise multiply chains

Posted by mb...@apache.org.
[SYSTEMML-1663] Fix and enable rewrite element-wise multiply chains

Closes #567.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/85e3a963
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/85e3a963
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/85e3a963

Branch: refs/heads/master
Commit: 85e3a9631081e8f44a21e4c7e99bbcf03859cba7
Parents: 1b3dff0 14bd65a
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Jul 14 21:10:18 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Jul 14 21:10:19 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |   8 +-
 .../RewriteAlgebraicSimplificationDynamic.java  |  45 ++++-
 ...RewriteElementwiseMultChainOptimization.java | 189 +++++++++++++------
 ...ElementwiseMultChainOptimizationAllTest.java | 134 +++++++++++++
 ...iteElementwiseMultChainOptimizationTest.java |   4 +-
 .../functions/misc/RewriteEMultChainOpAll.R     |  37 ++++
 .../functions/misc/RewriteEMultChainOpAll.dml   |  31 +++
 7 files changed, 379 insertions(+), 69 deletions(-)
----------------------------------------------------------------------



[11/23] systemml git commit: Review comments, part 2

Posted by mb...@apache.org.
Review comments, part 2


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d6d37952
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d6d37952
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d6d37952

Branch: refs/heads/master
Commit: d6d37952bdb1cd76d61c376b2051292ca272ee0a
Parents: 737f93b
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 02:27:34 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:45 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 23 +++++-------
 ...RewriteElementwiseMultChainOptimization.java | 38 +++++++-------------
 .../java/org/apache/sysml/utils/Explain.java    |  2 +-
 3 files changed, 22 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d6d37952/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 8f71359..b98901a 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,22 +246,15 @@ public class HopRewriteUtils
 	 * Replace an old Hop with a replacement Hop.
 	 * If the old Hop has no parents, then return the replacement.
 	 * Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
-	 * @param old To be replaced
-	 * @param replacement The replacement
-	 * @return replacement
+	 * @param hold To be replaced
+	 * @param hnew The replacement
+	 * @return hnew
 	 */
-	public static Hop replaceHop(final Hop old, final Hop replacement) {
-		if (old.getParent().isEmpty())
-			return replacement; // new old!
-		HopRewriteUtils.rewireAllParentChildReferences(old, replacement);
-		return replacement;
-	}
-
-
-	public static void rewireAllParentChildReferences( Hop hold, Hop hnew ) {
-		ArrayList<Hop> parents = new ArrayList<Hop>(hold.getParent());
-		for( Hop lparent : parents )
-			HopRewriteUtils.replaceChildReference(lparent, hold, hnew);
+	public static Hop rewireAllParentChildReferences( Hop hold, Hop hnew ) {
+		ArrayList<Hop> parents = hold.getParent();
+		while (!parents.isEmpty())
+			HopRewriteUtils.replaceChildReference(parents.get(0), hold, hnew);
+		return hnew;
 	}
 	
 	public static void replaceChildReference( Hop parent, Hop inOld, Hop inNew ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d6d37952/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 1dd5813..91b7306 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -98,7 +98,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 								emults.size(), root.getHopID(), replacement.getHopID()));
 
 					// 5. Replace root with replacement
-					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+					final Hop newRoot = HopRewriteUtils.rewireAllParentChildReferences(root, replacement);
 
 					// 6. Recurse at leaves (no need to repeat the interior emults)
 					for (final Hop leaf : leaves.elementSet()) {
@@ -166,6 +166,18 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	 * Disambiguate by Hop ID.
 	 */
 	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		private final int[] orderDataType = new int[Expression.DataType.values().length];
+		{
+			for (int i = 0, valuesLength = Expression.DataType.values().length; i < valuesLength; i++)
+				switch(Expression.DataType.values()[i]) {
+				case SCALAR: orderDataType[i] = 4; break;
+				case MATRIX: orderDataType[i] = 3; break;
+				case FRAME:  orderDataType[i] = 2; break;
+				case OBJECT: orderDataType[i] = 1; break;
+				case UNKNOWN:orderDataType[i] = 0; break;
+				}
+		}
+
 		@Override
 		public final int compare(Hop o1, Hop o2) {
 			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
@@ -198,30 +210,6 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 			if (c != 0) return c;
 			return Long.compare(o1.getHopID(), o2.getHopID());
 		}
-		private final int[] orderDataType;
-		{
-			Expression.DataType[] dtValues = Expression.DataType.values();
-			orderDataType = new int[dtValues.length];
-			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
-				switch(dtValues[i]) {
-				case SCALAR:
-					orderDataType[i] = 4;
-					break;
-				case MATRIX:
-					orderDataType[i] = 3;
-					break;
-				case FRAME:
-					orderDataType[i] = 2;
-					break;
-				case OBJECT:
-					orderDataType[i] = 1;
-					break;
-				case UNKNOWN:
-					orderDataType[i] = 0;
-					break;
-				}
-			}
-		}
 	}.reversed();
 
 	/**

http://git-wip-us.apache.org/repos/asf/systemml/blob/d6d37952/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 450c6e5..6451396 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -76,7 +76,7 @@ public class Explain
 	//internal configuration parameters
 	private static final boolean REPLACE_SPECIAL_CHARACTERS = true;	
 	private static final boolean SHOW_MEM_ABOVE_BUDGET      = true;
-	private static final boolean SHOW_LITERAL_HOPS          = true;
+	private static final boolean SHOW_LITERAL_HOPS          = false;
 	private static final boolean SHOW_DATA_DEPENDENCIES     = true;
 	private static final boolean SHOW_DATA_FLOW_PROPERTIES  = true;
 


[08/23] systemml git commit: Document RewriteEMult. Add smart recursion.

Posted by mb...@apache.org.
Document RewriteEMult. Add smart recursion.

RewriteEMult now rewrites emult chains deeper than the top-most one.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/d88f867f
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/d88f867f
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/d88f867f

Branch: refs/heads/master
Commit: d88f867fd0384954dce9e6ce4d65f02f1054bc5e
Parents: a5846bb
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sat Jun 10 01:17:36 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:33 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  2 +
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 90 +++++++++++++-------
 .../org/apache/sysml/parser/Expression.java     |  1 -
 3 files changed, 62 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 4d23cb9..17ac4ec 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -246,6 +246,8 @@ public class HopRewriteUtils
 	 * Replace an old Hop with a replacement Hop.
 	 * If the old Hop has no parents, then return the replacement.
 	 * Otherwise rewire each of the Hop's parents into the replacement and return the replacement.
+	 * @param old To be replaced
+	 * @param replacement The replacement
 	 * @return replacement
 	 */
 	public static Hop replaceHop(final Hop old, final Hop replacement) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index d483a08..5cd1471 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -42,6 +42,7 @@ import com.google.common.collect.Multiset;
  *
  * Rewrite a chain of element-wise multiply hops that contain identical elements.
  * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
@@ -74,18 +75,15 @@ public class RewriteEMult extends HopRewriteRule {
 			return root;
 		root.setVisited();
 
-		final ArrayList<Hop> rootInputs = root.getInput();
-
 		// 1. Find immediate subtree of EMults.
 		if (isBinaryMult(root)) {
-			final Hop left = rootInputs.get(0), right = rootInputs.get(1);
-			final BinaryOp r = (BinaryOp)root;
+			final Hop left = root.getInput().get(0), right = root.getInput().get(1);
 			final Set<BinaryOp> emults = new HashSet<>();
 			final Multiset<Hop> leaves = HashMultiset.create();
-			findEMultsAndLeaves(r, emults, leaves);
+			findEMultsAndLeaves((BinaryOp)root, emults, leaves);
 
 			// 2. Ensure it is profitable to do a rewrite.
-			if (isOptimizable(leaves)) {
+			if (isOptimizable(emults, leaves)) {
 				// 3. Check for foreign parents.
 				// A foreign parent is a parent of some EMult that is not in the set.
 				// Foreign parents destroy correctness of this rewrite.
@@ -94,25 +92,35 @@ public class RewriteEMult extends HopRewriteRule {
 				if (okay) {
 					// 4. Construct replacement EMults for the leaves
 					final Hop replacement = constructReplacement(leaves);
-					// 5. Replace root with replacement
 					if (LOG.isDebugEnabled())
 						LOG.debug(String.format(
 								"Element-wise multiply chain rewrite of %d e-mults at sub-dag %d to new sub-dag %d",
 								emults.size(), root.getHopID(), replacement.getHopID()));
-					replacement.setVisited();
-					return HopRewriteUtils.replaceHop(root, replacement);
+
+					// 5. Replace root with replacement
+					final Hop newRoot = HopRewriteUtils.replaceHop(root, replacement);
+
+					// 6. Recurse at leaves (no need to repeat the interior emults)
+					for (final Hop leaf : leaves.elementSet()) {
+						recurseInputs(leaf);
+					}
+					return newRoot;
 				}
 			}
 		}
-
 		// This rewrite is not applicable to the current root.
 		// Try the root's children.
-		for (int i = 0; i < rootInputs.size(); i++) {
-			final Hop input = rootInputs.get(i);
+		recurseInputs(root);
+		return root;
+	}
+
+	private static void recurseInputs(final Hop parent) {
+		final ArrayList<Hop> inputs = parent.getInput();
+		for (int i = 0; i < inputs.size(); i++) {
+			final Hop input = inputs.get(i);
 			final Hop newInput = rule_RewriteEMult(input);
-			rootInputs.set(i, newInput);
+			inputs.set(i, newInput);
 		}
-		return root;
 	}
 
 	private static Hop constructReplacement(final Multiset<Hop> leaves) {
@@ -133,6 +141,7 @@ public class RewriteEMult extends HopRewriteRule {
 		for (int i = 1; i < sorted.size(); i++) {
 			final Hop second = constructPower(iterator.next());
 			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
+			first.setVisited();
 		}
 		return first;
 	}
@@ -141,16 +150,21 @@ public class RewriteEMult extends HopRewriteRule {
 		final Hop hop = entry.getKey();
 		final int cnt = entry.getValue();
 		assert(cnt >= 1);
-		if (cnt == 1) return hop;
-		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		if (cnt == 1)
+			return hop; // don't set this visited... we will visit this next
+		Hop pow = HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
+		pow.setVisited();
+		return pow;
 	}
 
-
-
-	// Order: scalars > row vectors > col vectors >
-	//        non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
-	//        other data types
-	// disambiguate by Hop ID
+	/**
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
+	 * The order is as follows:
+	 * 		scalars > row vectors > col vectors >
+	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	 *      other data types.
+	 * Disambiguate by Hop ID.
+	 */
 	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
 		@Override
 		public final int compare(Hop o1, Hop o2) {
@@ -211,6 +225,12 @@ public class RewriteEMult extends HopRewriteRule {
 		}
 	};
 
+	/**
+	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param child An interior emult hop in the emult chain dag.
+	 * @return Whether this interior emult or any child emult has a foreign parent.
+	 */
 	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
 		final ArrayList<Hop> parents = child.getParent();
 		if (parents.size() > 1)
@@ -227,7 +247,10 @@ public class RewriteEMult extends HopRewriteRule {
 	}
 
 	/**
-	 * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root.
+	 * Create a set of the counts of all BinaryOp MULTs in the immediate subtree, starting with root, recursively.
+	 * @param root Root of sub-dag
+	 * @param emults Out parameter. The set of BinaryOp element-wise multiply hops in the emult chain (including root).
+	 * @param leaves Out parameter. The multiset of multiplicands in the emult chain.
 	 */
 	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
 		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
@@ -243,12 +266,19 @@ public class RewriteEMult extends HopRewriteRule {
 		else leaves.add(right);
 	}
 
-	/** Only optimize a subtree of EMults if at least one leaf occurs more than once. */
-	private static boolean isOptimizable(final Multiset<Hop> set) {
-		for (Multiset.Entry<Hop> hopEntry : set.entrySet()) {
-			if (hopEntry.getCount() > 1)
-				return true;
-		}
-		return false;
+	/**
+	 * Only optimize a subtree of emults if there are at least two emults.
+	 * @param emults The set of BinaryOp element-wise multiply hops in the emult chain.
+	 * @param leaves The multiset of multiplicands in the emult chain.
+	 * @return If the multiset is worth optimizing.
+	 */
+	private static boolean isOptimizable(Set<BinaryOp> emults, final Multiset<Hop> leaves) {
+		// Old criterion: there should be at least one repeated leaf
+//		for (Multiset.Entry<Hop> hopEntry : leaves.entrySet()) {
+//			if (hopEntry.getCount() > 1)
+//				return true;
+//		}
+//		return false;
+		return emults.size() >= 2;
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/d88f867f/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index b944e29..9ee3fba 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -162,7 +162,6 @@ public abstract class Expression
 	 * Data types (matrix, scalar, frame, object, unknown).
 	 */
 	public enum DataType {
-		// Careful: the order of these enums is significant! See RewriteEMult.comparatorByDataType
 		MATRIX, SCALAR, FRAME, OBJECT, UNKNOWN;
 		
 		public boolean isMatrix() {


[07/23] systemml git commit: simplifyDotProductSum shall not interfere with tak+*

Posted by mb...@apache.org.
simplifyDotProductSum shall not interfere with tak+*

Added conditions to the dynamic algebraic rewrite simplifyDotProductSum
that do not apply the optimization for (A^2)*B or B*(A^2),
since TernaryAggregate handles these.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5846bbb
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5846bbb
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5846bbb

Branch: refs/heads/master
Commit: a5846bbb383c655189963bffefed1c0db4ffcc89
Parents: edbac3b
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 23:58:16 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:30 2017 -0700

----------------------------------------------------------------------
 .../hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a5846bbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index ad80c05..166af2f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2050,7 +2050,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum
 					&& hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1
 					&& !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
-					&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) )
+					&& !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
+					&& !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW)      // do not rewrite (A^2)*B
+						&& hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp  // let tak+* handle it
+						&& ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2)
+					&& !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW)      // do not rewrite B*(A^2)
+						&& hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp  // let tak+* handle it
+						&& ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2)
+					)
 			{
 				baLeft = hi2.getInput().get(0);
 				baRight = hi2.getInput().get(1);


[14/23] systemml git commit: Review comments 3

Posted by mb...@apache.org.
Review comments 3


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04f692df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04f692df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04f692df

Branch: refs/heads/master
Commit: 04f692dfcb25a032044dabb7064241073f959300
Parents: de469d2
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 18 16:54:51 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:54 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |   4 +-
 .../sysml/hops/rewrite/ProgramRewriter.java     |   2 +-
 ...ementwiseMultChainOptimizationChainTest.java | 127 -------------------
 ...iteElementwiseMultChainOptimizationTest.java | 127 +++++++++++++++++++
 .../functions/misc/ZPackageSuite.java           |   2 +-
 5 files changed, 132 insertions(+), 130 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index a207831..8e681c1 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -647,7 +647,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 			handled = true;
 		} else if (input11 instanceof BinaryOp ) {
 			BinaryOp b11 = (BinaryOp)input11;
-			switch (b11.getOp()) {
+			switch( b11.getOp() ) {
 			case MULT: // A*B*C case
 				in1 = input11.getInput().get(0).constructLops();
 				in2 = input11.getInput().get(1).constructLops();
@@ -664,6 +664,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 					handled = true;
 				}
 				break;
+			default: break;
 			}
 		} else if( input12 instanceof BinaryOp ) {
 			BinaryOp b12 = (BinaryOp)input12;
@@ -683,6 +684,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 					handled = true;
 				}
 				break;
+			default: break;
 			}
 		}
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 1053850..7ee3ccb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -97,7 +97,7 @@ public class ProgramRewriter
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
 			if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
-				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()                              ); //dependency: cse
+				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()   ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )

http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
deleted file mode 100644
index e490750..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- * 
- *   http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
- */
-public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
-{
-	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
-	private static final String TEST_DIR = "functions/misc/";
-	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
-	
-	private static final int rows = 123;
-	private static final int cols = 321;
-	private static final double eps = Math.pow(10, -10);
-	
-	@Override
-	public void setUp() {
-		TestUtils.clearAssertionInformation();
-		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
-	}
-
-	@Test
-	public void testMatrixMultChainOptNoRewritesCP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptNoRewritesSP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptRewritesCP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
-	}
-	
-	@Test
-	public void testMatrixMultChainOptRewritesSP() {
-		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
-	}
-
-	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
-	{	
-		RUNTIME_PLATFORM platformOld = rtplatform;
-		switch( et ){
-			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
-			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
-			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
-		}
-		
-		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-		if( rtplatform == RUNTIME_PLATFORM.SPARK )
-			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-		
-		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
-		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-		
-		try
-		{
-			TestConfiguration config = getTestConfiguration(testname);
-			loadTestConfiguration(config);
-			
-			String HOME = SCRIPT_DIR + TEST_DIR;
-			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
-			fullRScriptName = HOME + testname + ".R";
-			rCmd = getRCmd(inputDir(), expectedDir());
-
-			double Xsparsity = 0.8, Ysparsity = 0.6;
-			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
-			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
-			writeInputMatrixWithMTD("X", X, true);
-			writeInputMatrixWithMTD("Y", Y, true);
-
-			//execute tests
-			runTest(true, false, null, -1); 
-			runRScript(true); 
-			
-			//compare matrices 
-			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
-			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
-			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-			
-			//check for presence of power operator, if we did a rewrite
-			if( rewrites ) {
-				Assert.assertTrue(heavyHittersContainsSubString("^2"));
-			}
-		}
-		finally {
-			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-			rtplatform = platformOld;
-			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
-		}
-	}
-}

http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
new file mode 100644
index 0000000..91cb4e0
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
+ */
+public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 123;
+	private static final int cols = 321;
+	private static final double eps = Math.pow(10, -10);
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testMatrixMultChainOptNoRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptNoRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesCP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+	}
+	
+	@Test
+	public void testMatrixMultChainOptRewritesSP() {
+		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+	{	
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		switch( et ){
+			case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+			case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+			default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+		}
+		
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		if( rtplatform == RUNTIME_PLATFORM.SPARK )
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+		
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+		
+		try
+		{
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+			writeInputMatrixWithMTD("X", X, true);
+			writeInputMatrixWithMTD("Y", Y, true);
+
+			//execute tests
+			runTest(true, false, null, -1); 
+			runRScript(true); 
+			
+			//compare matrices 
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			HashMap<CellIndex, Double> rfile  = readRMatrixFromFS("R");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+			
+			//check for presence of power operator, if we did a rewrite
+			if( rewrites ) {
+				Assert.assertTrue(heavyHittersContainsSubString("^2"));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index deea784..860cdbe 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,7 +50,7 @@ import org.junit.runners.Suite;
 	ReadAfterWriteTest.class,
 	RewriteCSETransposeScalarTest.class,
 	RewriteCTableToRExpandTest.class,
-	RewriteElementwiseMultChainOptimizationChainTest.class,
+	RewriteElementwiseMultChainOptimizationTest.class,
 	RewriteEliminateAggregatesTest.class,
 	RewriteFuseBinaryOpChainTest.class,
 	RewriteFusedRandTest.class,


[10/23] systemml git commit: Add scalars to Rewrite Emult test

Posted by mb...@apache.org.
Add scalars to Rewrite Emult test

Not sure how to check this in an assert statement


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/737f93b1
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/737f93b1
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/737f93b1

Branch: refs/heads/master
Commit: 737f93b15a96aba31bc6c6da3651be309e3b8b0c
Parents: b94557f
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 11 01:56:07 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:41 2017 -0700

----------------------------------------------------------------------
 .../hops/rewrite/RewriteElementwiseMultChainOptimization.java   | 2 +-
 src/main/java/org/apache/sysml/utils/Explain.java               | 4 ++--
 .../misc/RewriteElementwiseMultChainOptimizationChainTest.java  | 4 ++--
 .../integration/functions/ternary/ABATernaryAggregateTest.java  | 5 +----
 src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R        | 4 ++--
 src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml      | 2 +-
 6 files changed, 9 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index bd873ff..1dd5813 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -222,7 +222,7 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 				}
 			}
 		}
-	};
+	}.reversed();
 
 	/**
 	 * Check if a node has a parent that is not in the set of emults. Recursively check children who are also emults.

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/main/java/org/apache/sysml/utils/Explain.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 5cf0548..450c6e5 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -76,7 +76,7 @@ public class Explain
 	//internal configuration parameters
 	private static final boolean REPLACE_SPECIAL_CHARACTERS = true;	
 	private static final boolean SHOW_MEM_ABOVE_BUDGET      = true;
-	private static final boolean SHOW_LITERAL_HOPS          = false;
+	private static final boolean SHOW_LITERAL_HOPS          = true;
 	private static final boolean SHOW_DATA_DEPENDENCIES     = true;
 	private static final boolean SHOW_DATA_FLOW_PROPERTIES  = true;
 
@@ -566,7 +566,7 @@ public class Explain
 			childs.append(" (");
 			boolean childAdded = false;
 			for( Hop input : hop.getInput() )
-				if( !(input instanceof LiteralOp) ){
+				if( SHOW_LITERAL_HOPS || !(input instanceof LiteralOp) ){
 					childs.append(childAdded?",":"");
 					childs.append(input.getHopID());
 					childAdded = true;

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
index 47b2f0e..e490750 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
@@ -33,7 +33,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 /**
- * Test whether `A*B*A` successfully rewrites to `(A^2)*B`.
+ * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
  */
 public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
 {
@@ -96,7 +96,7 @@ public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedT
 			fullDMLScriptName = HOME + testname + ".dml";
 			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
 			fullRScriptName = HOME + testname + ".R";
-			rCmd = getRCmd(inputDir(), expectedDir());			
+			rCmd = getRCmd(inputDir(), expectedDir());
 
 			double Xsparsity = 0.8, Ysparsity = 0.6;
 			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
index 460829d..12525c9 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/ternary/ABATernaryAggregateTest.java
@@ -368,15 +368,13 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 		if( rtplatform == RUNTIME_PLATFORM.SPARK )
 			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
 	
-		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES,
-				rewritesOldEmult = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+		boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
 		
 		try {
 			TestConfiguration config = getTestConfiguration(testname);
 			loadTestConfiguration(config);
 			
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
 
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
@@ -412,7 +410,6 @@ public class ABATernaryAggregateTest extends AutomatedTestBase
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
-			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOldEmult;
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
index 6d94cc8..fec61ae 100644
--- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -28,6 +28,6 @@ library("matrixStats")
 X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
 Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
 
-R = X * Y * X;
+R = 2 * X * 3 * Y * 4 * X;
 
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/737f93b1/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
index 3992403..88f252f 100644
--- a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -23,6 +23,6 @@
 X = read($1);
 Y = read($2);
 
-R = X * Y * X;
+R = 2 * X * 3 * Y * 4 * X;
 
 write(R, $3);
\ No newline at end of file


[04/23] systemml git commit: Correct ordering of e-mult chain rewrites.

Posted by mb...@apache.org.
Correct ordering of e-mult chain rewrites.

Sorting scalars, vectors, matrices appropriately and by sparsity (when nnz information is available).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8b832f62
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8b832f62
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8b832f62

Branch: refs/heads/master
Commit: 8b832f624dd23ba0006672c444cf6f0649a6e753
Parents: ff8c836
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 20:48:57 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:21 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/rewrite/RewriteEMult.java | 78 ++++++++++++++++++--
 .../functions/misc/RewriteEMultChainTest.java   |  7 +-
 2 files changed, 74 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
index 66da6fa..d483a08 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteEMult.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.HopsException;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.parser.Expression;
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Multiset;
@@ -125,13 +126,13 @@ public class RewriteEMult extends HopRewriteRule {
 		}
 		// sorted contains all leaves, sorted by data type, stripped from their parents
 
-		// Construct left-deep EMult tree
-		Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
+		// Construct right-deep EMult tree
+		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
 		Hop first = constructPower(iterator.next());
 
 		for (int i = 1; i < sorted.size(); i++) {
 			final Hop second = constructPower(iterator.next());
-			first = HopRewriteUtils.createBinary(first, second, Hop.OpOp2.MULT);
+			first = HopRewriteUtils.createBinary(second, first, Hop.OpOp2.MULT);
 		}
 		return first;
 	}
@@ -140,14 +141,75 @@ public class RewriteEMult extends HopRewriteRule {
 		final Hop hop = entry.getKey();
 		final int cnt = entry.getValue();
 		assert(cnt >= 1);
-		if (cnt == 1)
-			return hop;
+		if (cnt == 1) return hop;
 		return HopRewriteUtils.createBinary(hop, new LiteralOp(cnt), Hop.OpOp2.POW);
 	}
 
-	private static Comparator<Hop> compareByDataType = Comparator.comparing(Hop::getDataType)
-			.thenComparing(Hop::getName)
-			.thenComparingInt(Object::hashCode);
+
+
+	// Order: scalars > row vectors > col vectors >
+	//        non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
+	//        other data types
+	// disambiguate by Hop ID
+	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
+		@Override
+		public final int compare(Hop o1, Hop o2) {
+			int c = Integer.compare(orderDataType[o1.getDataType().ordinal()], orderDataType[o2.getDataType().ordinal()]);
+			if (c != 0) return c;
+
+			// o1 and o2 have the same data type
+			switch (o1.getDataType()) {
+			case SCALAR: return Long.compare(o1.getHopID(), o2.getHopID());
+			case MATRIX:
+				// two matrices; check for vectors
+				if (o1.getDim1() == 1) { // row vector
+						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+						return compareBySparsityThenId(o1, o2); // both row vectors
+				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
+						return -1; // row vectors are the greatest matrices
+				} else if (o1.getDim2() == 1) { // col vector
+						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+						return compareBySparsityThenId(o1, o2); // both col vectors
+				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
+						return 1; // col vectors greater than non-vectors
+				} else { // both non-vectors
+						return compareBySparsityThenId(o1, o2);
+				}
+			default:
+				return Long.compare(o1.getHopID(), o2.getHopID());
+			}
+		}
+		private int compareBySparsityThenId(Hop o1, Hop o2) {
+			// the hop with more nnz is first; unknown nnz (-1) last
+			int c = Long.compare(o1.getNnz(), o2.getNnz());
+			if (c != 0) return c;
+			return Long.compare(o1.getHopID(), o2.getHopID());
+		}
+		private final int[] orderDataType;
+		{
+			Expression.DataType[] dtValues = Expression.DataType.values();
+			orderDataType = new int[dtValues.length];
+			for (int i = 0, valuesLength = dtValues.length; i < valuesLength; i++) {
+				switch(dtValues[i]) {
+				case SCALAR:
+					orderDataType[i] = 4;
+					break;
+				case MATRIX:
+					orderDataType[i] = 3;
+					break;
+				case FRAME:
+					orderDataType[i] = 2;
+					break;
+				case OBJECT:
+					orderDataType[i] = 1;
+					break;
+				case UNKNOWN:
+					orderDataType[i] = 0;
+					break;
+				}
+			}
+		}
+	};
 
 	private static boolean checkForeignParent(final Set<BinaryOp> emults, final BinaryOp child) {
 		final ArrayList<Hop> parents = child.getParent();

http://git-wip-us.apache.org/repos/asf/systemml/blob/8b832f62/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index e076c95..18ed55d 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -99,8 +99,9 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			fullRScriptName = HOME + testname + ".R";
 			rCmd = getRCmd(inputDir(), expectedDir());			
 
-			double[][] X = getRandomMatrix(rows, cols, -1, 1, 0.97d, 7);
-			double[][] Y = getRandomMatrix(rows, cols, -1, 1, 0.9d, 3);
+			double Xsparsity = 0.8, Ysparsity = 0.6;
+			double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
 			writeInputMatrixWithMTD("X", X, true);
 			writeInputMatrixWithMTD("Y", Y, true);
 			
@@ -123,5 +124,5 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			rtplatform = platformOld;
 			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
 		}
-	}	
+	}
 }


[22/23] systemml git commit: Ensure the `wumm` new pattern does not have foreign parents

Posted by mb...@apache.org.
Ensure the `wumm` new pattern does not have foreign parents


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/14bd65a5
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/14bd65a5
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/14bd65a5

Branch: refs/heads/master
Commit: 14bd65a551b227860faba56ed1633db85d7110f2
Parents: 479b9da
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Thu Jul 13 01:41:51 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Thu Jul 13 01:41:51 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/14bd65a5/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 6246270..09b66de 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -1960,7 +1960,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			LOG.debug("Applied simplifyWeightedUnaryMM1 (line "+hi.getBeginLine()+")");	
 		}
 
-		//Pattern 1.5) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
+		//Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
 		if( !appliedPattern
 				&& hi instanceof BinaryOp && HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
 				&& (HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
@@ -1974,6 +1974,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 			}
 
 			if (       HopRewriteUtils.isBinary(nl, OpOp2.MULT)
+					&& nl.getParent().size()==1 // ensure no foreign parents
 					&& HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1)) //prevent mv
 					&& nl.getDim2() > 1 //not applied for vector-vector mult
 					&& nl.getInput().get(0).getDataType() == DataType.MATRIX


[17/23] systemml git commit: Merge branch 'rewrite-emult' into rewrite-emult2

Posted by mb...@apache.org.
Merge branch 'rewrite-emult' into rewrite-emult2

# Conflicts:
#	src/main/java/org/apache/sysml/hops/rewrite/HopDagValidator.java
#	src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
#	src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
#	src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3c4d777a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3c4d777a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3c4d777a

Branch: refs/heads/master
Commit: 3c4d777a5a15ed59681547e91b84c8812d3420fc
Parents: b67f186 999fdfb
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Tue Jul 11 22:26:59 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Tue Jul 11 22:37:36 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |   4 +-
 ...RewriteElementwiseMultChainOptimization.java | 184 +++++++++++++------
 ...ElementwiseMultChainOptimizationAllTest.java | 134 ++++++++++++++
 ...iteElementwiseMultChainOptimizationTest.java |   4 +-
 .../functions/misc/RewriteEMultChainOpAll.R     |  37 ++++
 .../functions/misc/RewriteEMultChainOpAll.dml   |  31 ++++
 6 files changed, 334 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3c4d777a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------