You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/04/18 07:12:51 UTC

[4/5] systemml git commit: [SYSTEMML-2252] Improved parfor rewrite update-in-place result indexing

[SYSTEMML-2252] Improved parfor rewrite update-in-place result indexing

This patch improves the existing update-in-place result indexing
rewrite of the parfor optimizer. In detail, this includes (1) a fix for
computing the total size of pinned result variables (which incorrectly
double counted the degree of parallelism), and (2) properly computing the
remaining memory estimates without double counting the result indexing
in case this operation is the max memory consumer.

Furthermore, this patch also cleans up the parfor cost/memory
estimator to avoid unnecessary trace-message construction and to leverage Java streams.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2838cefd
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2838cefd
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2838cefd

Branch: refs/heads/master
Commit: 2838cefd9a361992a41173e713ef659738d0aed6
Parents: dfc48ae3
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Apr 17 23:03:02 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Apr 17 23:03:02 2018 -0700

----------------------------------------------------------------------
 .../parfor/opt/CostEstimator.java               | 49 +++++++++----------
 .../parfor/opt/CostEstimatorHops.java           | 48 ++++++++++--------
 .../parfor/opt/CostEstimatorRuntime.java        |  4 +-
 .../parfor/opt/OptimizerConstrained.java        | 10 ++--
 .../parfor/opt/OptimizerRuleBased.java          | 51 +++++++++-----------
 5 files changed, 78 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/2838cefd/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimator.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimator.java
index 2f5392e..167120f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimator.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimator.java
@@ -20,11 +20,13 @@
 package org.apache.sysml.runtime.controlprogram.parfor.opt;
 
 import java.util.ArrayList;
+import java.util.Collection;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.parser.ParForStatementBlock.ResultVar;
 import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode.ParamType;
 
 /**
@@ -34,9 +36,9 @@ import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode.ParamType;
  * 
  */
 public abstract class CostEstimator 
-{	
+{
 	protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());
-    	
+
 	//default parameters
 	public static final double DEFAULT_EST_PARALLELISM = 1.0; //default degree of parallelism: serial
 	public static final long   FACTOR_NUM_ITERATIONS   = 10; //default problem size
@@ -46,7 +48,7 @@ public abstract class CostEstimator
 
 	public enum TestMeasure {
 		EXEC_TIME,
-		MEMORY_USAGE	
+		MEMORY_USAGE
 	}
 	
 	public enum DataFormat {
@@ -55,6 +57,7 @@ public abstract class CostEstimator
 	}
 	
 	protected boolean _inclCondPart = false;
+	protected Collection<ResultVar> _exclRetVars = null;
 	
 	/**
 	 * Main leaf node estimation method - to be overwritten by specific cost estimators
@@ -93,12 +96,18 @@ public abstract class CostEstimator
 	}
 	
 	public double getEstimate( TestMeasure measure, OptNode node, boolean inclCondPart ) {
-		//temporarily change local flag and get estimate
-		boolean oldInclCondPart = _inclCondPart;
-		_inclCondPart = inclCondPart; 
+		_inclCondPart = inclCondPart; //temporary
+		double val = getEstimate(measure, node, null);
+		_inclCondPart = false;
+		return val;
+	}
+	
+	public double getEstimate(TestMeasure measure, OptNode node, boolean inclCondPart, Collection<ResultVar> retVars) {
+		_inclCondPart = inclCondPart; //temporary
+		_exclRetVars = retVars;
 		double val = getEstimate(measure, node, null);
-		//reset local flag and return
-		_inclCondPart = oldInclCondPart;
+		_inclCondPart = false; 
+		_exclRetVars = null;
 		return val;
 	}
 	
@@ -125,8 +134,6 @@ public abstract class CostEstimator
 		else
 		{
 			//aggreagtion methods for different program block types and measure types
-			//TODO EXEC TIME requires reconsideration of for/parfor/if predicates 
-			//TODO MEMORY requires reconsideration of parfor -> potential overestimation, but safe
 			String tmp = null;
 			double N = -1;
 			switch ( measure )
@@ -190,34 +197,24 @@ public abstract class CostEstimator
 		return val;
 	}
 
-	protected double getDefaultEstimate(TestMeasure measure)  {
+	protected double getDefaultEstimate(TestMeasure measure) {
 		switch( measure ) {
 			case EXEC_TIME:    return DEFAULT_TIME_ESTIMATE;
 			case MEMORY_USAGE: return DEFAULT_MEM_ESTIMATE_CP;
-		}		
+		}
 		return -1;
 	}
 
 	protected double getMaxEstimate( TestMeasure measure, ArrayList<OptNode> nodes, ExecType et ) {
-		double max = Double.MIN_VALUE; //smallest positive value
-		for( OptNode n : nodes )
-			max = Math.max(max, getEstimate(measure, n, et));
-		return max;
+		return nodes.stream().mapToDouble(n -> getEstimate(measure, n, et))
+			.max().orElse(Double.NEGATIVE_INFINITY);
 	}
 
 	protected double getSumEstimate( TestMeasure measure, ArrayList<OptNode> nodes, ExecType et ) {
-		double sum = 0;
-		for( OptNode n : nodes )
-			sum += getEstimate( measure, n, et );
-		return sum;
+		return nodes.stream().mapToDouble(n -> getEstimate(measure, n, et)).sum();
 	}
 
 	protected double getWeightedEstimate( TestMeasure measure, ArrayList<OptNode> nodes, ExecType et ) {
-		double ret = 0;
-		int len = nodes.size();
-		for( OptNode n : nodes )
-			ret += getEstimate( measure, n, et );
-		ret /= len; //weighting
-		return ret;
+		return nodes.stream().mapToDouble(n -> getEstimate(measure, n, et)).sum() / nodes.size(); //weighting
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/2838cefd/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorHops.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
index 55d9c0c..3cb0c6f 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
@@ -21,8 +21,10 @@ package org.apache.sysml.runtime.controlprogram.parfor.opt;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.LeftIndexingOp;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.parser.ParForStatementBlock.ResultVar;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode.NodeType;
 import org.apache.sysml.runtime.controlprogram.parfor.opt.Optimizer.CostModelType;
@@ -30,15 +32,15 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 
 public class CostEstimatorHops extends CostEstimator
 {
-	public static double DEFAULT_MEM_MR = -1;
-	public static double DEFAULT_MEM_SP = 20*1024*1024;
+	public static final double DEFAULT_MEM_SP = 20*1024*1024;
+	public static final double DEFAULT_MEM_MR;
 	
 	private OptTreePlanMappingAbstract _map = null;
 	
 	static {
-		DEFAULT_MEM_MR = DEFAULT_MEM_ESTIMATE_MR; //20MB
-		if( InfrastructureAnalyzer.isLocalMode() )
-			DEFAULT_MEM_MR = DEFAULT_MEM_MR + InfrastructureAnalyzer.getRemoteMaxMemorySortBuffer();
+		DEFAULT_MEM_MR = DEFAULT_MEM_ESTIMATE_MR //20MB
+			+ (InfrastructureAnalyzer.isLocalMode() ?
+			InfrastructureAnalyzer.getRemoteMaxMemorySortBuffer() : 0);
 	}
 	
 	
@@ -61,30 +63,26 @@ public class CostEstimatorHops extends CostEstimator
 		
 		//handle specific cases 
 		double DEFAULT_MEM_REMOTE = OptimizerUtils.isSparkExecutionMode() ? 
-								DEFAULT_MEM_SP : DEFAULT_MEM_MR;
+			DEFAULT_MEM_SP : DEFAULT_MEM_MR;
 		
-		if( value >= DEFAULT_MEM_REMOTE )   	  
+		if( value >= DEFAULT_MEM_REMOTE )
 		{
 			//check for CP estimate but MR type
-			if( h.getExecType()==ExecType.MR ) 
-			{
+			if( h.getExecType()==ExecType.MR ) {
 				value = DEFAULT_MEM_REMOTE;
 			}
 			//check for CP estimate but Spark type (include broadcast requirements)
-			else if( h.getExecType()==ExecType.SPARK )
-			{
+			else if( h.getExecType()==ExecType.SPARK ) {
 				value = DEFAULT_MEM_REMOTE + h.getSpBroadcastSize();
 			}
 			//check for invalid cp memory estimate
-			else if ( h.getExecType()==ExecType.CP && value >= OptimizerUtils.getLocalMemBudget() )
-			{
+			else if ( h.getExecType()==ExecType.CP && value >= OptimizerUtils.getLocalMemBudget() ) {
 				if( DMLScript.rtplatform != DMLScript.RUNTIME_PLATFORM.SINGLE_NODE && h.getForcedExecType()==null )
 					LOG.warn("Memory estimate larger than budget but CP exec type (op="+h.getOpString()+", name="+h.getName()+", memest="+h.getMemEstimate()+").");
 				value = DEFAULT_MEM_REMOTE;
 			}
 			//check for non-existing exec type
-			else if ( h.getExecType()==null)
-			{
+			else if ( h.getExecType()==null) {
 				//note: if exec type is 'null' lops have never been created (e.g., r(T) for tsmm),
 				//in that case, we do not need to raise a warning 
 				value = DEFAULT_MEM_REMOTE;
@@ -92,18 +90,23 @@ public class CostEstimatorHops extends CostEstimator
 		}
 		
 		//check for forced runtime platform
-		if( h.getForcedExecType()==ExecType.MR  || h.getForcedExecType()==ExecType.SPARK) 
-		{
+		if( h.getForcedExecType()==ExecType.MR  || h.getForcedExecType()==ExecType.SPARK) {
 			value = DEFAULT_MEM_REMOTE;
 		}
 		
-		if( value <= 0 ) //no mem estimate
-		{
+		if( value <= 0 ) { //no mem estimate
 			LOG.warn("Cannot get memory estimate for hop (op="+h.getOpString()+", name="+h.getName()+", memest="+h.getMemEstimate()+").");
 			value = CostEstimator.DEFAULT_MEM_ESTIMATE_CP;
 		}
 		
-		LOG.trace("Memory estimate "+h.getName()+", "+h.getOpString()+"("+node.getExecType()+")"+"="+OptimizerRuleBased.toMB(value));
+		//correction for disabled result indexing
+		value = (_exclRetVars!=null && h instanceof LeftIndexingOp
+			&& ResultVar.contains(_exclRetVars, h.getName())) ? 0 : value;
+		
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("Memory estimate "+h.getName()+", "+h.getOpString()
+				+"("+node.getExecType()+")"+"="+OptimizerRuleBased.toMB(value));
+		}
 		
 		return value;
 	}
@@ -125,7 +128,10 @@ public class CostEstimatorHops extends CostEstimator
 		if( value <= 0 ) //no mem estimate
 			value = CostEstimator.DEFAULT_MEM_ESTIMATE_CP;
 		
-		LOG.trace("Memory estimate (forced exec type) "+h.getName()+", "+h.getOpString()+"("+node.getExecType()+")"+"="+OptimizerRuleBased.toMB(value));
+		if( LOG.isTraceEnabled() ) {
+			LOG.trace("Memory estimate (forced exec type) "+h.getName()+", "
+				+h.getOpString()+"("+node.getExecType()+")"+"="+OptimizerRuleBased.toMB(value));
+		}
 		
 		return value;
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/2838cefd/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorRuntime.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorRuntime.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorRuntime.java
index 9a842a7..ee05af0 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorRuntime.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/CostEstimatorRuntime.java
@@ -71,10 +71,10 @@ public class CostEstimatorRuntime extends CostEstimator
 		//(currently only called for entire parfor program in order to
 		//decide for LOCAL vs REMOTE parfor execution)
 		double ret = DEFAULT_TIME_ESTIMATE;
-		boolean isCP = (et == ExecType.CP || et == null);		
+		boolean isCP = (et == ExecType.CP || et == null);
 		if( !node.isLeaf() && isCP ) {
 			ProgramBlock pb = (ProgramBlock)_map.getMappedProg(node.getID())[1];
-			ret = CostEstimationWrapper.getTimeEstimate(pb, _ec, true);				
+			ret = CostEstimationWrapper.getTimeEstimate(pb, _ec, true);
 		}
 		return ret;
 	}

http://git-wip-us.apache.org/repos/asf/systemml/blob/2838cefd/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerConstrained.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
index 68731c1..4088e63 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerConstrained.java
@@ -54,9 +54,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.opt.OptNode.ParamType;
  * - 4) rewrite set execution strategy
  * - 9) rewrite set degree of parallelism
  * - 10) rewrite set task partitioner
- * - 11) rewrite set result merge 		 		
- *
- * TODO generalize for nested parfor (currently only awareness of top-level constraints, if present leave child as they are)
+ * - 11) rewrite set result merge
  *
  */
 public class OptimizerConstrained extends OptimizerRuleBased
@@ -167,7 +165,7 @@ public class OptimizerConstrained extends OptimizerRuleBased
 
 			//rewrite 14:
 			HashSet<ResultVar> inplaceResultVars = new HashSet<>();
-			super.rewriteSetInPlaceResultIndexing(pn, M1, ec.getVariables(), inplaceResultVars, ec);
+			super.rewriteSetInPlaceResultIndexing(pn, _cost, ec.getVariables(), inplaceResultVars, ec);
 
 			//rewrite 15:
 			super.rewriteDisableCPCaching(pn, inplaceResultVars, ec.getVariables());
@@ -183,7 +181,7 @@ public class OptimizerConstrained extends OptimizerRuleBased
 
 			// rewrite 14: set in-place result indexing
 			HashSet<ResultVar> inplaceResultVars = new HashSet<>();
-			super.rewriteSetInPlaceResultIndexing(pn, M1, ec.getVariables(), inplaceResultVars, ec);
+			super.rewriteSetInPlaceResultIndexing(pn, _cost, ec.getVariables(), inplaceResultVars, ec);
 
 			if( !OptimizerUtils.isSparkExecutionMode() ) {
 				// rewrite 16: runtime piggybacking
@@ -316,7 +314,7 @@ public class OptimizerConstrained extends OptimizerRuleBased
 		if( !pn.getParam(ParamType.TASK_PARTITIONER).equals(PTaskPartitioner.UNSPECIFIED.toString()) )
 		{
 			ParForProgramBlock pfpb = (ParForProgramBlock) OptTreeConverter
-                    .getAbstractPlanMapping().getMappedProg(pn.getID())[1];
+				.getAbstractPlanMapping().getMappedProg(pn.getID())[1];
 			pfpb.setTaskPartitioner(PTaskPartitioner.valueOf(pn.getParam(ParamType.TASK_PARTITIONER)));
 			String tsExt = "";
 			if( pn.getParam(ParamType.TASK_SIZE)!=null )

http://git-wip-us.apache.org/repos/asf/systemml/blob/2838cefd/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 536acb3..2bf8fe5 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -166,25 +166,22 @@ public class OptimizerRuleBased extends Optimizer
 	protected double _lm = -1; //local memory constraint
 	protected double _rm = -1; //remote memory constraint (mappers)
 	protected double _rm2 = -1; //remote memory constraint (reducers)
-		
+	
 	protected CostEstimator _cost = null;
 
 	@Override
-	public CostModelType getCostModelType() 
-	{
+	public CostModelType getCostModelType() {
 		return CostModelType.STATIC_MEM_METRIC;
 	}
 
 
 	@Override
-	public PlanInputType getPlanInputType() 
-	{
+	public PlanInputType getPlanInputType() {
 		return PlanInputType.ABSTRACT_PLAN;
 	}
 
 	@Override
-	public POptMode getOptMode() 
-	{
+	public POptMode getOptMode() {
 		return POptMode.RULEBASED;
 	}
 	
@@ -261,7 +258,7 @@ public class OptimizerRuleBased extends Optimizer
 			if( flagRecompMR ){
 				//rewrite 5: set operations exec type
 				rewriteSetOperationsExecType( pn, flagRecompMR );
-				M1 = _cost.getEstimate(TestMeasure.MEMORY_USAGE, pn); //reestimate 		
+				M1 = _cost.getEstimate(TestMeasure.MEMORY_USAGE, pn); //reestimate
 			}
 			
 			// rewrite 6: data colocation
@@ -287,7 +284,7 @@ public class OptimizerRuleBased extends Optimizer
 			
 			// rewrite 14: set in-place result indexing
 			HashSet<ResultVar> inplaceResultVars = new HashSet<>();
-			rewriteSetInPlaceResultIndexing(pn, M1, ec.getVariables(), inplaceResultVars, ec);
+			rewriteSetInPlaceResultIndexing(pn, _cost, ec.getVariables(), inplaceResultVars, ec);
 			
 			// rewrite 15: disable caching
 			rewriteDisableCPCaching(pn, inplaceResultVars, ec.getVariables());
@@ -302,7 +299,7 @@ public class OptimizerRuleBased extends Optimizer
 			
 			// rewrite 14: set in-place result indexing
 			HashSet<ResultVar> inplaceResultVars = new HashSet<>();
-			rewriteSetInPlaceResultIndexing(pn, M1, ec.getVariables(), inplaceResultVars, ec);
+			rewriteSetInPlaceResultIndexing(pn, _cost, ec.getVariables(), inplaceResultVars, ec);
 			
 			if( !OptimizerUtils.isSparkExecutionMode() ) {
 				// rewrite 16: runtime piggybacking
@@ -1609,7 +1606,7 @@ public class OptimizerRuleBased extends Optimizer
 			if( inMatrix.equals(varName) )
 			{
 				//check that all parents are transpose-safe operations
-				//(even a transient write would not be safe due to indirection into other DAGs)			
+				//(even a transient write would not be safe due to indirection into other DAGs)
 				ArrayList<Hop> parent = h.getParent();
 				for( Hop p : parent )
 					ret &= p.isTransposeSafe();
@@ -1624,7 +1621,7 @@ public class OptimizerRuleBased extends Optimizer
 	//REWRITE set in-place result indexing
 	///
 
-	protected void rewriteSetInPlaceResultIndexing(OptNode pn, double M, LocalVariableMap vars, HashSet<ResultVar> inPlaceResultVars, ExecutionContext ec) 
+	protected void rewriteSetInPlaceResultIndexing(OptNode pn, CostEstimator cost, LocalVariableMap vars, HashSet<ResultVar> inPlaceResultVars, ExecutionContext ec) 
 	{
 		//assertions (warnings of corrupt optimizer decisions)
 		if( pn.getNodeType() != NodeType.PARFOR )
@@ -1639,16 +1636,17 @@ public class OptimizerRuleBased extends Optimizer
 		//only if all fit pinned in remaining budget, we apply this rewrite.
 		ArrayList<ResultVar> retVars = pfpb.getResultVariables();
 		
-		//compute total sum of pinned result variable memory
-		double sum = computeTotalSizeResultVariables(retVars, vars, pfpb.getDegreeOfParallelism());
-		
-		//NOTE: currently this rule is too conservative (the result variable is assumed to be dense and
-		//most importantly counted twice if this is part of the maximum operation)
-		double totalMem = Math.max((M+sum), rComputeSumMemoryIntermediates(pn, new HashSet<ResultVar>()));
-		
-		//optimization decision
-		if( rHasOnlyInPlaceSafeLeftIndexing(pn, retVars) ) //basic correctness constraint
+		//basic correctness constraint
+		double totalMem = -1;
+		if( rHasOnlyInPlaceSafeLeftIndexing(pn, retVars) )
 		{
+			//compute total sum of pinned result variable memory 
+			double sum = computeTotalSizeResultVariables(retVars, vars, pfpb.getDegreeOfParallelism());
+		
+			//compute memory estimate without result indexing, and total sum per worker
+			double M = cost.getEstimate(TestMeasure.MEMORY_USAGE, pn, true, retVars);
+			totalMem = M + sum;
+			
 			//result update in-place for MR/Spark (w/ remote memory constraint)
 			if( (  pfpb.getExecMode() == PExecMode.REMOTE_MR_DP || pfpb.getExecMode() == PExecMode.REMOTE_MR
 				|| pfpb.getExecMode() == PExecMode.REMOTE_SPARK_DP || pfpb.getExecMode() == PExecMode.REMOTE_SPARK) 
@@ -1706,14 +1704,9 @@ public class OptimizerRuleBased extends Optimizer
 			if( !(dat instanceof MatrixObject) )
 				continue;
 			MatrixObject mo = (MatrixObject)dat;
-			if( mo.getNnz() == 0 ) 
-				sum += OptimizerUtils.estimateSizeExactSparsity(mo.getNumRows(), mo.getNumColumns(), 1.0);
-			else {
-				// Every worker will consume memory for (MatrixSize/k + nnz) data.
-				// This is applicable only when there is non-zero nnz. 
-				sum += (k+1) * (OptimizerUtils.estimateSizeExactSparsity(mo.getNumRows(), 
-					mo.getNumColumns(), Math.min((1.0/k)+mo.getSparsity(), 1.0)));
-			} 
+			// every worker will consume memory for at most (max_nnz/k + in_nnz)
+			sum += (OptimizerUtils.estimateSizeExactSparsity(mo.getNumRows(), 
+				mo.getNumColumns(), Math.min((1.0/k)+mo.getSparsity(), 1.0)));
 		}
 		return sum;
 	}