You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2019/03/17 19:23:47 UTC

[systemml] branch master updated (881f606 -> f42dfb3)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git.


    from 881f606  [MINOR] Provide a more informative error message when the dimensions don't match during the validate phase
     new 4a38a47  [MINOR] Fix unnecessary warnings (unnecessary imports)
     new f42dfb3  [SYSTEMML-2521] New rewrite for sparsity-aware matrix product chains

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../org/apache/sysml/api/ScriptExecutorUtils.java  |   1 -
 .../apache/sysml/api/mlcontext/ScriptExecutor.java |   1 -
 .../java/org/apache/sysml/hops/FunctionOp.java     |   1 -
 .../sysml/hops/estim/EstimatorMatrixHistogram.java |   2 +-
 .../sysml/hops/rewrite/ProgramRewriteStatus.java   |  16 ++-
 .../RewriteMatrixMultChainOptimization.java        |  88 ++++++------
 .../RewriteMatrixMultChainOptimizationSparse.java  | 157 +++++++++++++++++++++
 .../java/org/apache/sysml/utils/Statistics.java    |   1 -
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     |   2 -
 .../functions/unary/matrix/AbsTest.java            |   2 -
 .../functions/unary/matrix/NegationTest.java       |   2 -
 .../functions/unary/matrix/SinTest.java            |   2 -
 .../functions/unary/matrix/TanTest.java            |   2 -
 13 files changed, 215 insertions(+), 62 deletions(-)
 create mode 100644 src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java


[systemml] 01/02: [MINOR] Fix unnecessary warnings (unnecessary imports)

Posted by mb...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git

commit 4a38a4789302741965f49b4dd559a7078d94eb69
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Mar 17 12:09:33 2019 +0100

    [MINOR] Fix unnecessary warnings (unnecessary imports)
---
 src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java             | 1 -
 src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java        | 1 -
 src/main/java/org/apache/sysml/hops/FunctionOp.java                     | 1 -
 src/main/java/org/apache/sysml/utils/Statistics.java                    | 1 -
 src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java                | 2 --
 .../apache/sysml/test/integration/functions/unary/matrix/AbsTest.java   | 2 --
 .../sysml/test/integration/functions/unary/matrix/NegationTest.java     | 2 --
 .../apache/sysml/test/integration/functions/unary/matrix/SinTest.java   | 2 --
 .../apache/sysml/test/integration/functions/unary/matrix/TanTest.java   | 2 --
 9 files changed, 14 deletions(-)

diff --git a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
index 0d072e5..c9d1a5d 100644
--- a/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
+++ b/src/main/java/org/apache/sysml/api/ScriptExecutorUtils.java
@@ -19,7 +19,6 @@
 
 package org.apache.sysml.api;
 
-import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashSet;
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index 7bda306..8ecd962 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -38,7 +38,6 @@ import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.conf.DMLOptions;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.DMLTranslator;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.parser.ParserFactory;
 import org.apache.sysml.parser.ParserWrapper;
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index dedbad6..534c0a0 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -22,7 +22,6 @@ package org.apache.sysml.hops;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.lops.FunctionCallCP;
 import org.apache.sysml.lops.FunctionCallCPSingle;
 import org.apache.sysml.lops.Lop;
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index 656de32..a2afae0 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -32,7 +32,6 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.DoubleAdder;
 import java.util.concurrent.atomic.LongAdder;
 
-import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.hops.OptimizerUtils;
diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
index 5c93bca..4c4ab74 100644
--- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
+++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java
@@ -23,8 +23,6 @@ import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 
-import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction;
-import org.apache.sysml.runtime.instructions.gpu.DnnGPUInstruction.LstmOperator;
 import org.apache.sysml.test.utils.TestUtils;
 import org.junit.Test;
 
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/AbsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/AbsTest.java
index a3027d6..6b61066 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/AbsTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/AbsTest.java
@@ -20,8 +20,6 @@
 package org.apache.sysml.test.integration.functions.unary.matrix;
 
 import org.junit.Test;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/NegationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/NegationTest.java
index c2613c2..6b2000a 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/NegationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/NegationTest.java
@@ -20,8 +20,6 @@
 package org.apache.sysml.test.integration.functions.unary.matrix;
 
 import org.junit.Test;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/SinTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/SinTest.java
index b5f7c26..523bf7c 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/SinTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/SinTest.java
@@ -20,8 +20,6 @@
 package org.apache.sysml.test.integration.functions.unary.matrix;
 
 import org.junit.Test;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/TanTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/TanTest.java
index 497e393..0180796 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/TanTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/TanTest.java
@@ -20,8 +20,6 @@
 package org.apache.sysml.test.integration.functions.unary.matrix;
 
 import org.junit.Test;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 


[systemml] 02/02: [SYSTEMML-2521] New rewrite for sparsity-aware matrix product chains

Posted by mb...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git

commit f42dfb358ac24b6633d01dd181b51d458cd1bbe7
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sun Mar 17 20:23:33 2019 +0100

    [SYSTEMML-2521] New rewrite for sparsity-aware matrix product chains
    
    This patch introduces a new dynamic rewrite for sparsity-aware matrix
    multiplication chain optimization. For estimating the sparsity of
    intermediates, we use the existing MNC sparsity estimator.
    
    While this rewrite does find the optimal plan in case of perfect
    estimates, it currently requires access to all input matrices of the mm
    chain and these inputs need to fit into CP memory. Accordingly, this
    rewrite is still disabled by default.
---
 .../sysml/hops/estim/EstimatorMatrixHistogram.java |   2 +-
 .../sysml/hops/rewrite/ProgramRewriteStatus.java   |  16 ++-
 .../RewriteMatrixMultChainOptimization.java        |  88 ++++++------
 .../RewriteMatrixMultChainOptimizationSparse.java  | 157 +++++++++++++++++++++
 4 files changed, 215 insertions(+), 48 deletions(-)

diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
index 5f1abff..b079a7e 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java
@@ -59,7 +59,7 @@ public class EstimatorMatrixHistogram extends SparsityEstimator
 		return estim(root, true);
 	}
 	
-	private MatrixCharacteristics estim(MMNode root, boolean topLevel) {
+	public MatrixCharacteristics estim(MMNode root, boolean topLevel) {
 		//NOTE: not estimateInputs due to handling of topLevel
 		MatrixHistogram h1 = getCachedSynopsis(root.getLeft());
 		MatrixHistogram h2 = getCachedSynopsis(root.getRight());
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriteStatus.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriteStatus.java
index 552a598..a622948 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriteStatus.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriteStatus.java
@@ -19,9 +19,10 @@
 
 package org.apache.sysml.hops.rewrite;
 
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+
 public class ProgramRewriteStatus 
 {
-	
 	//status of applied rewrites
 	private boolean _rmBranches = false; //removed branches
 	private int _blkSize = -1;
@@ -29,14 +30,19 @@ public class ProgramRewriteStatus
 	
 	//current context
 	private boolean _inParforCtx = false;
+	private LocalVariableMap _vars = null;
 	
-	public ProgramRewriteStatus()
-	{
+	public ProgramRewriteStatus() {
 		_rmBranches = false;
 		_inParforCtx = false;
 		_injectCheckpoints = false;
 	}
 	
+	public ProgramRewriteStatus(LocalVariableMap vars) {
+		this();
+		_vars = vars;
+	}
+	
 	public void setRemovedBranches(){
 		_rmBranches = true;
 	}
@@ -68,4 +74,8 @@ public class ProgramRewriteStatus
 	public boolean getInjectedCheckpoints(){
 		return _injectCheckpoints;
 	}
+	
+	public LocalVariableMap getVariables() {
+		return _vars;
+	}
 }
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java
index 91033c4..cdb1e12 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimization.java
@@ -35,14 +35,16 @@ import org.apache.sysml.utils.Explain;
 
 /**
  * Rule: Determine the optimal order of execution for a chain of
- * matrix multiplications Solution: Classic Dynamic Programming
- * Approach Currently, the approach based only on matrix dimensions
+ * matrix multiplications 
+ * 
+ * Solution: Classic Dynamic Programming
+ * Approach: Currently, the approach based only on matrix dimensions
  * Goal: To reduce the number of computations in the run-time
  * (map-reduce) layer
  */
 public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 {
-	private static final Log LOG = LogFactory.getLog(RewriteMatrixMultChainOptimization.class.getName());
+	protected static final Log LOG = LogFactory.getLog(RewriteMatrixMultChainOptimization.class.getName());
 	private static final boolean LDEBUG = false;
 	
 	static {
@@ -61,7 +63,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 
 		// Find the optimal order for the chain whose result is the current HOP
 		for( Hop h : roots ) 
-			rule_OptimizeMMChains(h);
+			rule_OptimizeMMChains(h, state);
 		
 		return roots;
 	}
@@ -73,7 +75,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 			return null;
 
 		// Find the optimal order for the chain whose result is the current HOP
-		rule_OptimizeMMChains(root);
+		rule_OptimizeMMChains(root, state);
 		
 		return root;
 	}
@@ -84,7 +86,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 	 * 
 	 * @param hop high-level operator
 	 */
-	private void rule_OptimizeMMChains(Hop hop) 
+	private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state) 
 	{
 		if( hop.isVisited() )
 			return;
@@ -94,11 +96,11 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 		{
 			// Try to find and optimize the chain in which current Hop is the
 			// last operator
-			optimizeMMChain(hop);
+			prepAndOptimizeMMChain(hop, state);
 		}
 		
 		for( Hop hi : hop.getInput() )
-			rule_OptimizeMMChains(hi);
+			rule_OptimizeMMChains(hi, state);
 
 		hop.setVisited();
 	}
@@ -113,7 +115,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 	 * 
 	 * @param hop high-level operator
 	 */
-	private void optimizeMMChain( Hop hop )
+	private void prepAndOptimizeMMChain( Hop hop, ProgramRewriteStatus state )
 	{
 		if( LOG.isTraceEnabled() ) {
 			LOG.trace("MM Chain Optimization for HOP: (" + hop.getClass().getSimpleName()
@@ -149,18 +151,15 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 			 *    (either within chain or outside the chain)
 			 */
 			
-			if (    HopRewriteUtils.isMatrixMultiply(h)
-			     && !((AggBinaryOp)hop).hasLeftPMInput() && !h.isVisited() ) 
+			if ( HopRewriteUtils.isMatrixMultiply(h)
+				&& !((AggBinaryOp)hop).hasLeftPMInput() && !h.isVisited() ) 
 			{
 				// check if the output of "h" is used at multiple places. If yes, it can
 				// not be expanded.
-				if( h.getParent().size() > 1 || inputCount(h.getParent().get(0), h) > 1 ) {
-					expandable = false;
+				expandable = !(h.getParent().size() > 1 
+					|| inputCount(h.getParent().get(0), h) > 1);
+				if( !expandable )
 					break;
-				}
-				else {
-					expandable = true;
-				}
 			}
 
 			h.setVisited();
@@ -189,30 +188,31 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 			}
 		}
 
-		if( mmChain.size() == 2 ) {
-			// If the chain size is 2, then there is nothing to optimize.
-			return;
-		} 
-		else 
-		{
-			// Step 2: construct dims array
-			double[] dimsArray = new double[mmChain.size() + 1];
-			boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
+		//core mmchain optimization (potentially overridden)
+		if( mmChain.size() == 2 ) 
+			return; //nothing to optimize
+		else
+			optimizeMMChain(hop, mmChain, mmOperators, state);
+	}
+	
+	protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
+		// Step 2: construct dims array
+		double[] dimsArray = new double[mmChain.size() + 1];
+		boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
+		
+		if( dimsKnown ) {
+			// Step 3: clear the links among Hops within the identified chain
+			clearLinksWithinChain ( hop, mmOperators );
 			
-			if( dimsKnown ) {
-				// Step 3: clear the links among Hops within the identified chain
-				clearLinksWithinChain ( hop, mmOperators );
-				
-				// Step 4: Find the optimal ordering via dynamic programming.
-				
-				// Invoke Dynamic Programming
-				int size = mmChain.size();
-				int[][] split = mmChainDP(dimsArray, mmChain.size());
-				
-				 // Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
-				LOG.trace("Optimal MM Chain: ");
-				mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
-			}
+			// Step 4: Find the optimal ordering via dynamic programming.
+			
+			// Invoke Dynamic Programming
+			int size = mmChain.size();
+			int[][] split = mmChainDP(dimsArray, mmChain.size());
+			
+			 // Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
+			LOG.trace("Optimal MM Chain: ");
+			mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
 		}
 	}
 	
@@ -244,7 +244,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 				{
 					//recursive cost computation
 					double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] 
-							  + (dimArray[i] * dimArray[k + 1] * dimArray[j + 1]);
+						+ (dimArray[i] * dimArray[k + 1] * dimArray[j + 1]);
 					
 					//prune suboptimal
 					if( cost < dpMatrix[i][j] ) {
@@ -271,7 +271,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 	 * three Hops in mmChain (B,C,D), and two Hops in mmOperators (one for each
 	 * %*%) .
 	 */
-	private void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
+	protected final void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
 			int opIndex, int[][] split, int level) 
 	{
 		//single matrix - end of recursion
@@ -320,7 +320,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 		}
 	}
 
-	private static void clearLinksWithinChain( Hop hop, ArrayList<Hop> operators ) 
+	protected static void clearLinksWithinChain( Hop hop, ArrayList<Hop> operators ) 
 	{
 		for( int i=0; i < operators.size(); i++ ) {
 			Hop op = operators.get(i);
@@ -347,7 +347,7 @@ public class RewriteMatrixMultChainOptimization extends HopRewriteRule
 	 * @param dimArray dimension array
 	 * @return true if all dimensions known
 	 */
-	private static boolean getDimsArray( Hop hop, ArrayList<Hop> chain, double[] dimsArray ) 
+	protected static boolean getDimsArray( Hop hop, ArrayList<Hop> chain, double[] dimsArray )
 	{
 		boolean dimsKnown = true;
 		
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java
new file mode 100644
index 0000000..d30d87b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.HopsException;
+import org.apache.sysml.hops.Hop.DataOpTypes;
+import org.apache.sysml.hops.estim.MMNode;
+import org.apache.sysml.hops.estim.EstimatorMatrixHistogram;
+import org.apache.sysml.hops.estim.EstimatorMatrixHistogram.MatrixHistogram;
+import org.apache.sysml.hops.estim.SparsityEstimator.OpCode;
+import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Rule: Determine the optimal order of execution for a chain of
+ * matrix multiplications
+ * 
+ * Solution: Classic Dynamic Programming
+ * Approach: Currently, the approach is based on matrix dimensions
+ * and sparsity estimates using the MNC sketch
+ * Goal: To reduce the number of computations in the run-time
+ * (map-reduce) layer
+ */
+public class RewriteMatrixMultChainOptimizationSparse extends RewriteMatrixMultChainOptimization
+{
+	@Override
+	protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
+		// Step 2: construct dims array and input sketches (all chain inputs must be transient reads)
+		double[] dimsArray = new double[mmChain.size() + 1];
+		boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
+		MMNode[] sketchArray = new MMNode[mmChain.size() + 1];
+		boolean inputsAvail = getInputMatrices(hop, mmChain, sketchArray, state);
+		
+		if( dimsKnown && inputsAvail ) {
+			// Step 3: clear the links among Hops within the identified chain
+			clearLinksWithinChain ( hop, mmOperators );
+			
+			// Step 4: Find the optimal ordering via sparsity-aware dynamic programming.
+			
+			// Invoke Dynamic Programming
+			int size = mmChain.size();
+			int[][] split = mmChainDPSparse(dimsArray, sketchArray, mmChain.size());
+			
+			// Step 5: Relink the hops using the optimal ordering (split[][]) found from DP.
+			LOG.trace("Optimal MM Chain: ");
+			mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
+		}
+	}
+	
+	/**
+	 * mmChainDPSparse(): Core method to perform dynamic programming on a given array
+	 * of matrix dimensions and MNC sketches (cost per product: sum_k lhsColNnz[k]*rhsRowNnz[k]).
+	 * 
+	 * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford Stein
+	 * Introduction to Algorithms, Third Edition, MIT Press, page 395.
+	 */
+	private static int[][] mmChainDPSparse(double[] dimArray, MMNode[] sketchArray, int size) 
+	{
+		double[][] dpMatrix = new double[size][size]; //min cost table
+		MMNode[][] dpMatrixS = new MMNode[size][size]; //min sketch table
+		int[][] split = new int[size][size]; //min cost index table
+
+		//init minimum costs for chains of length 1
+		for( int i = 0; i < size; i++ ) {
+			Arrays.fill(dpMatrix[i], 0);
+			Arrays.fill(split[i], -1);
+			dpMatrixS[i][i] = sketchArray[i];
+		}
+
+		//compute cost-optimal chains for increasing chain sizes 
+		EstimatorMatrixHistogram estim = new EstimatorMatrixHistogram(true);
+		for( int l = 2; l <= size; l++ ) { // chain length
+			for( int i = 0; i < size - l + 1; i++ ) {
+				int j = i + l - 1;
+				// find cost of (i,j)
+				dpMatrix[i][j] = Double.MAX_VALUE;
+				for( int k = i; k <= j - 1; k++ ) 
+				{
+					//construct estimation nodes (w/ lazy propagation and memoization)
+					MMNode tmp = new MMNode(dpMatrixS[i][k], dpMatrixS[k+1][j], OpCode.MM);
+					estim.estim(tmp, false);
+					MatrixHistogram lhs = (MatrixHistogram) dpMatrixS[i][k].getSynopsis();
+					MatrixHistogram rhs = (MatrixHistogram) dpMatrixS[k+1][j].getSynopsis();
+					
+					//recursive cost computation (nnz-based count of scalar multiplications)
+					double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] 
+						+ dotProduct(lhs.getColCounts(), rhs.getRowCounts());
+					
+					//prune suboptimal
+					if( cost < dpMatrix[i][j] ) {
+						dpMatrix[i][j] = cost;
+						dpMatrixS[i][j] = tmp;
+						split[i][j] = k;
+					}
+				}
+
+				if( LOG.isTraceEnabled() ){
+					LOG.trace("mmchainopt [i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split = "+(split[i][j]+1));
+				}
+			}
+		}
+
+		return split;
+	}
+	
+	private boolean getInputMatrices(Hop hop, ArrayList<Hop> chain, MMNode[] sketchArray, ProgramRewriteStatus state) {
+		LocalVariableMap vars = (state != null) ? state.getVariables() : null;
+		boolean inputsAvail = (vars != null); //no live variables (e.g., default status) -> no sketches
+		
+		for( int i=0; i<chain.size(); i++ ) {
+			inputsAvail &= HopRewriteUtils.isData(chain.get(i), DataOpTypes.TRANSIENTREAD);
+			if( inputsAvail ) //every chain input must be a transient read of a live matrix
+				sketchArray[i] = new MMNode(getMatrix(chain.get(i).getName(), vars));
+			else 
+				break;
+		}
+		
+		return inputsAvail;
+	}
+	
+	private static MatrixBlock getMatrix(String name, LocalVariableMap vars) {
+		Data dat = vars.get(name);
+		if( !(dat instanceof MatrixObject) ) //also covers missing variables (dat == null)
+			throw new HopsException("Input '"+name+"' not a matrix: "+((dat != null) ? dat.getDataType() : "null"));
+		return ((MatrixObject)dat).acquireReadAndRelease();
+	}
+	
+	private static double dotProduct(int[] h1cNnz, int[] h2rNnz) {
+		long fp = 0; //long accumulator avoids int overflow of products
+		for( int j=0; j<h1cNnz.length; j++ )
+			fp += (long)h1cNnz[j] * h2rNnz[j];
+		return fp;
+	}
+}