You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/06/21 07:05:25 UTC

[2/2] systemml git commit: [HOTFIX][SYSTEMML-1663] Fix and disable element-wise mult chain rewrite

[HOTFIX][SYSTEMML-1663] Fix and disable element-wise mult chain rewrite

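This patch fixes the custom hop comparator used to order the operands of
element-wise multiplication chains (scalars, vectors, matrices). When only
the second operand was a column vector, the comparator returned 1 instead
of -1, violating the antisymmetry required by the Comparator contract.
This fixes the test failure reported in PR549. Due to additional issues
that could cause incorrect results or runtime errors, I'm temporarily
disabling this rewrite and the related tests.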

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5c834b2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5c834b2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5c834b2

Branch: refs/heads/master
Commit: a5c834b27da9cfeffe0ad6e606c43fe3246831d2
Parents: 9e7ce7b
Author: Matthias Boehm <mb...@gmail.com>
Authored: Wed Jun 21 00:05:32 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Wed Jun 21 00:05:32 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/rewrite/ProgramRewriter.java     |  6 ++---
 ...RewriteElementwiseMultChainOptimization.java | 27 ++++++++++++--------
 ...iteElementwiseMultChainOptimizationTest.java |  4 ++-
 3 files changed, 23 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 7ee3ccb..92d31c2 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -96,8 +96,8 @@ public class ProgramRewriter
 			_dagRuleSet.add(     new RewriteRemoveUnnecessaryCasts()             );		
 			if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
 				_dagRuleSet.add( new RewriteCommonSubexpressionElimination()     );
-			if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
-				_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()   ); //dependency: cse
+			//if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
+			//	_dagRuleSet.add( new RewriteElementwiseMultChainOptimization()   ); //dependency: cse
 			if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
 				_dagRuleSet.add( new RewriteConstantFolding()                    ); //dependency: cse
 			if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
@@ -108,7 +108,7 @@ public class ProgramRewriter
 				_dagRuleSet.add( new RewriteIndexingVectorization()              ); //dependency: cse, simplifications
 			_dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing()          ); //dependency: reblock
 			
-			//add statment block rewrite rules
+			//add statement block rewrite rules
  			if( OptimizerUtils.ALLOW_BRANCH_REMOVAL )			
 				_sbRuleSet.add(  new RewriteRemoveUnnecessaryBranches()          ); //dependency: constant folding		
  			if( OptimizerUtils.ALLOW_SPLIT_HOP_DAGS )

http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 9ca0932..fe2a5d0 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -42,7 +42,7 @@ import com.google.common.collect.Multiset;
  *
  * Rewrite a chain of element-wise multiply hops that contain identical elements.
  * For example `(B * A) * B` is rewritten to `A * (B^2)` (or `(B^2) * A`), where `^` is element-wise power.
- * The order of the multiplicands depends on their data types, dimentions (matrix or vector), and sparsity.
+ * The order of the multiplicands depends on their data types, dimensions (matrix or vector), and sparsity.
  *
  * Does not rewrite in the presence of foreign parents in the middle of the e-wise multiply chain,
  * since foreign parents may rely on the individual results.
@@ -136,6 +136,8 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 		// sorted contains all leaves, sorted by data type, stripped from their parents
 
 		// Construct right-deep EMult tree
+		// TODO compile binary outer mult for transition from row and column vectors to matrices
+		// TODO compile subtree for column vectors to avoid blow-up of intermediates on row-col vector transition
 		final Iterator<Map.Entry<Hop, Integer>> iterator = sorted.entrySet().iterator();
 		Hop first = constructPower(iterator.next());
 
@@ -160,13 +162,15 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	}
 
 	/**
-	 * A Comparator that orders Hops by their data type, dimention, and sparsity.
+	 * A Comparator that orders Hops by their data type, dimension, and sparsity.
 	 * The order is as follows:
 	 * 		scalars > row vectors > col vectors >
 	 *      non-vector matrices ordered by sparsity (higher nnz first, unknown sparsity last) >
 	 *      other data types.
 	 * Disambiguate by Hop ID.
 	 */
+	//TODO replace by ComparableHop wrapper around hop that implements equals and compareTo
+	//in order to ensure comparisons that are 'consistent with equals'
 	private static final Comparator<Hop> compareByDataType = new Comparator<Hop>() {
 		private final int[] orderDataType = new int[Expression.DataType.values().length];
 		{
@@ -190,17 +194,17 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 			case MATRIX:
 				// two matrices; check for vectors
 				if (o1.getDim1() == 1) { // row vector
-						if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
-						return compareBySparsityThenId(o1, o2); // both row vectors
+					if (o2.getDim1() != 1) return 1; // row vectors are greatest of matrices
+					return compareBySparsityThenId(o1, o2); // both row vectors
 				} else if (o2.getDim1() == 1) { // 2 is row vector; 1 is not
-						return -1; // row vectors are the greatest matrices
+					return -1; // row vectors are the greatest matrices
 				} else if (o1.getDim2() == 1) { // col vector
-						if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
-						return compareBySparsityThenId(o1, o2); // both col vectors
+					if (o2.getDim2() != 1) return 1; // col vectors greater than non-vectors
+					return compareBySparsityThenId(o1, o2); // both col vectors
 				} else if (o2.getDim2() == 1) { // 2 is col vector; 1 is not
-						return 1; // col vectors greater than non-vectors
+					return -1; // col vectors greater than non-vectors
 				} else { // both non-vectors
-						return compareBySparsityThenId(o1, o2);
+					return compareBySparsityThenId(o1, o2);
 				}
 			default:
 				return Long.compare(o1.getHopID(), o2.getHopID());
@@ -243,7 +247,10 @@ public class RewriteElementwiseMultChainOptimization extends HopRewriteRule {
 	private static void findEMultsAndLeaves(final BinaryOp root, final Set<BinaryOp> emults, final Multiset<Hop> leaves) {
 		// Because RewriteCommonSubexpressionElimination already ran, it is safe to compare by equality.
 		emults.add(root);
-
+		
+		// TODO proper handling of DAGs (avoid collecting the same leaf multiple times)
+		// TODO exclude hops with unknown dimensions and move rewrites to dynamic rewrites 
+		
 		final ArrayList<Hop> inputs = root.getInput();
 		final Hop left = inputs.get(0), right = inputs.get(1);
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/a5c834b2/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
index 91cb4e0..b16fa3e 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
@@ -50,7 +50,7 @@ public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBa
 		TestUtils.clearAssertionInformation();
 		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
 	}
-
+	
 	@Test
 	public void testMatrixMultChainOptNoRewritesCP() {
 		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
@@ -61,6 +61,7 @@ public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBa
 		testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
 	}
 	
+	/* TODO enable together with RewriteElementwiseMultChainOptimization
 	@Test
 	public void testMatrixMultChainOptRewritesCP() {
 		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
@@ -70,6 +71,7 @@ public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBa
 	public void testMatrixMultChainOptRewritesSP() {
 		testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
 	}
+	*/
 
 	private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
 	{