You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/09/01 19:12:11 UTC

[3/3] systemml git commit: [SYSTEMML-1888] Fix parfor optimizer exec type selection (msvm, kmeans)

[SYSTEMML-1888] Fix parfor optimizer exec type selection (msvm,kmeans)

This patch makes a number of small fixes to the parfor optimizer's
execution type selection to enable the compilation of remote parfor
plans in the presence of many inner iterations but relatively small
input data. In detail, this includes (besides various cleanups):

(1) Fixes for determining the number of inner iterations (including for
and while loops, with a default iteration count where it is unknown),

(2) A modified minimum datasize threshold that is now scaled by the
number of inner iterations, and

(3) Better handling of what-if memory estimation (e.g., all operations
in CP) that does not account for unnecessary hops that are never
compiled to lops due to hop-lop rewrites.

On the perftest MSVM 1M x 1K sparse scenario with 150 classes and 25
iterations, this patch improved the runtime from 288s to 94s on a 1+6
node cluster (including data read and Spark context creation).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/8dbc9302
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/8dbc9302
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/8dbc9302

Branch: refs/heads/master
Commit: 8dbc93022a01aae309354c7b2b2f0eee9ec11aad
Parents: 9178a95
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Sep 1 01:32:11 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Sep 1 12:13:27 2017 -0700

----------------------------------------------------------------------
 .../parfor/RemoteParForSparkWorker.java         |   4 +-
 .../controlprogram/parfor/opt/OptNode.java      | 328 +++++++------------
 .../parfor/opt/OptTreeConverter.java            |  25 +-
 .../parfor/opt/OptimizerRuleBased.java          |  15 +-
 4 files changed, 145 insertions(+), 227 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/8dbc9302/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
index cd4a673..e1410da 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteParForSparkWorker.java
@@ -84,7 +84,7 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun
 		ArrayList<String> tmp = RemoteParForUtils.exportResultVariables( _workerID, _ec.getVariables(), _resultVars );
 		for( String val : tmp )
 			ret.add(new Tuple2<Long,String>(_workerID, val));
-			
+		
 		return ret.iterator();
 	}
 
@@ -102,7 +102,7 @@ public class RemoteParForSparkWorker extends ParWorker implements PairFlatMapFun
 		//parse and setup parfor body program
 		ParForBody body = ProgramConverter.parseParForBody(_prog, (int)_workerID);
 		_childBlocks = body.getChildBlocks();
-		_ec          = body.getEc();				
+		_ec          = body.getEc();
 		_resultVars  = body.getResultVarNames();
 		_numTasks    = 0;
 		_numIters    = 0;

http://git-wip-us.apache.org/repos/asf/systemml/blob/8dbc9302/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java
index 22126fe..193ce3e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptNode.java
@@ -38,7 +38,6 @@ import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PartitionForma
  */
 public class OptNode 
 {
-	
 	public enum NodeType{
 		GENERIC,
 		FUNCCALL,
@@ -47,7 +46,11 @@ public class OptNode
 		FOR,
 		PARFOR,
 		INST,
-		HOP
+		HOP;
+		public boolean isLoop() {
+			return this == WHILE ||
+				this == FOR || this == PARFOR;
+		}
 	}
 	
 	public enum ExecType { 
@@ -96,260 +99,203 @@ public class OptNode
 	private int                       _beginLine = -1;
 	private int                       _endLine = -1;
 	
-	public OptNode( NodeType type )
-	{
+	public OptNode( NodeType type ) {
 		this(type, null);
 	}
 
-	public OptNode( NodeType ntype, ExecType etype )
-	{
+	public OptNode( NodeType ntype, ExecType etype ) {
 		_ntype = ntype;
 		_etype = etype;
-		
 		_k = 1;
 	}
 	
 	///////
 	//getters and setters
 	
-	public NodeType getNodeType() 
-	{
+	public NodeType getNodeType() {
 		return _ntype;
 	}
 	
-	public void setNodeType(NodeType type) 
-	{
+	public void setNodeType(NodeType type) {
 		_ntype = type;
 	}
 	
-	public ExecType getExecType() 
-	{
+	public ExecType getExecType() {
 		return _etype;
 	}
 	
-	public void setExecType(ExecType type) 
-	{
+	public void setExecType(ExecType type) {
 		_etype = type;
 	}
 	
-	public void setID( long id )
-	{
+	public void setID( long id ) {
 		_id = id;
 	}
 	
-	public long getID( )
-	{
+	public long getID( ) {
 		return _id;
 	}
 	
-	public void addParam(ParamType ptype, String val)
-	{
+	public void addParam(ParamType ptype, String val) {
 		if( _params == null )
 			_params = new HashMap<ParamType, String>();
-		
 		_params.put(ptype, val);
 	}
 
-	public void setParams( HashMap<ParamType,String> params )
-	{
+	public void setParams( HashMap<ParamType,String> params ) {
 		_params = params;
 	}
 	
-	public String getParam( ParamType type )
-	{
-		String ret = null;
-		if( _params != null )
-			ret = _params.get(type);
-		return ret;
+	public String getParam( ParamType type ) {
+		return (_params != null) ?
+			_params.get(type) : null;
 	}
 	
-	public int getBeginLine()
-	{
+	public int getBeginLine() {
 		return _beginLine;
 	}
 	
-	public void setBeginLine( int line )
-	{
+	public void setBeginLine( int line ) {
 		_beginLine = line;
 	}
 	
-	public int getEndLine()
-	{
+	public int getEndLine() {
 		return _endLine;
 	}
 	
-	public void setEndLine( int line )
-	{
+	public void setEndLine( int line ) {
 		_endLine = line;
 	}
 
-	public void setLineNumbers( int begin, int end )
-	{
+	public void setLineNumbers( int begin, int end ) {
 		setBeginLine( begin );
 		setEndLine( end );
 	}
 	
-	public void addChild( OptNode child )
-	{
+	public void addChild( OptNode child ) {
 		if( _childs==null )
 			_childs = new ArrayList<OptNode>();
-		
 		_childs.add( child );
 	}
 	
-	public void addChilds( ArrayList<OptNode> childs )
-	{
+	public void addChilds( ArrayList<OptNode> childs ) {
 		if( _childs==null )
 			_childs = new ArrayList<OptNode>();
-		
-		_childs.addAll( childs );		
+		_childs.addAll( childs );
 	}
 	
-	public void setChilds(ArrayList<OptNode> childs) 
-	{
+	public void setChilds(ArrayList<OptNode> childs) {
 		_childs = childs;
 	}
 	
-	public ArrayList<OptNode> getChilds() 
-	{
+	public ArrayList<OptNode> getChilds() {
 		return _childs;
 	}
 	
-	
-	public int getK() 
-	{
+	public int getK() {
 		return _k;
 	}
 
-	public void setK(int k) 
-	{
+	public void setK(int k) {
 		_k = k;
 	}
 	
-	public OptNodeStatistics getStatistics()
-	{
+	public OptNodeStatistics getStatistics() {
 		return _stats;
 	}
 	
-	public void setStatistics(OptNodeStatistics stats)
-	{
+	public void setStatistics(OptNodeStatistics stats) {
 		_stats = stats;
 	}
 
-	public boolean exchangeChild(OptNode oldNode, OptNode newNode) 
-	{
-		boolean ret = false;
-		
-		if( _childs != null )
-			for( int i=0; i<_childs.size(); i++ )
-				if( _childs.get(i) == oldNode )
-				{
-					_childs.set(i, newNode);
-					ret = true;
-				}
+	public boolean exchangeChild(OptNode oldNode, OptNode newNode) {
+		if( isLeaf() )
+			return false;
 		
+		boolean ret = false;
+		for( int i=0; i<_childs.size(); i++ )
+			if( _childs.get(i) == oldNode ) {
+				_childs.set(i, newNode);
+				ret = true;
+			}
 		return ret;
 	}
 
-	public boolean isLeaf()
-	{
+	public boolean isLeaf() {
 		return ( _childs == null || _childs.isEmpty() );
 	}
 
-	public boolean hasOnlySimpleChilds()
-	{
-		boolean ret = true;
-		if( !isLeaf() )
-			for( OptNode n : _childs ) {
-				if( n.getNodeType()==NodeType.GENERIC )
-					ret &= n.hasOnlySimpleChilds();
-				//functions, loops, branches
-				else if( n.getNodeType()!=NodeType.HOP )
-					return false;
-			}
+	public boolean hasOnlySimpleChilds() {
+		if( isLeaf() )
+			return true;
 		
+		boolean ret = true;
+		for( OptNode n : _childs ) {
+			if( n.getNodeType()==NodeType.GENERIC )
+				ret &= n.hasOnlySimpleChilds();
+			//functions, loops, branches
+			else if( n.getNodeType()!=NodeType.HOP )
+				return false;
+		}
 		return ret;
 	}
 
-	public String getInstructionName() 
-	{
+	public String getInstructionName() {
 		return String.valueOf(_etype) + Lop.OPERAND_DELIMITOR + getParam(ParamType.OPSTRING);
 	}
 
-	public boolean isRecursive()
-	{
-		boolean ret = false;
+	public boolean isRecursive() {
 		String rec = getParam(ParamType.RECURSIVE_CALL);
-		if( rec != null )
-			ret = Boolean.parseBoolean(rec);
-		return ret;
+		return (rec != null) ?
+			Boolean.parseBoolean(rec) : false;
 	}
 	
 
 	///////
 	//recursive methods
 
-	public Collection<OptNode> getNodeList()
-	{
+	public Collection<OptNode> getNodeList() {
 		Collection<OptNode> nodes = new LinkedList<OptNode>();
-		
 		if(!isLeaf())
 			for( OptNode n : _childs )
 				nodes.addAll( n.getNodeList() );
 		nodes.add(this);
-		
 		return nodes;
 	}
 
-	public Collection<OptNode> getNodeList( ExecType et )
-	{
+	public Collection<OptNode> getNodeList( ExecType et ) {
 		Collection<OptNode> nodes = new LinkedList<OptNode>();
-		
 		if(!isLeaf())
 			for( OptNode n : _childs )
 				nodes.addAll( n.getNodeList( et ) );
-		
 		if( _etype == et )
 			nodes.add(this);
-		
 		return nodes;
 	}
 
-	public Collection<OptNode> getRelevantNodeList()
-	{
+	public Collection<OptNode> getRelevantNodeList() {
 		Collection<OptNode> nodes = new LinkedList<OptNode>();
-		
 		if( !isLeaf() )
-		{
 			for( OptNode n : _childs )
 				nodes.addAll( n.getRelevantNodeList() );
-		}
-		 
 		if( _ntype == NodeType.PARFOR || _ntype == NodeType.HOP )
-		{
 			nodes.add(this);
-		}
-		
 		return nodes;
 	}
 	
 	
-
-	
 	/**
 	 * Set the plan to a parallel degree of 1 (serial execution).
 	 */
-	public void setSerialParFor()
-	{
+	public void setSerialParFor() {
 		//process parfor nodes
-		if( _ntype == NodeType.PARFOR )
-		{
+		if( _ntype == NodeType.PARFOR ) {
 			_k = 1;
 			_etype = ExecType.CP;
 		}
 		
 		//process childs
-		if( _childs != null )
+		if( !isLeaf() )
 			for( OptNode n : _childs )
 				n.setSerialParFor();
 	}
@@ -359,10 +305,9 @@ public class OptNode
 	 * 
 	 * @return number of plan nodes
 	 */
-	public int size() 
-	{
+	public int size() {
 		int count = 1; //self
-		if( _childs != null )
+		if( !isLeaf() )
 			for( OptNode n : _childs )
 				count += n.size();
 		return count;
@@ -374,27 +319,23 @@ public class OptNode
 	 * 
 	 * @return true of all program blocks and instructions execute on CP
 	 */
-	public boolean isCPOnly()
-	{
-		boolean ret = (_etype == ExecType.CP);		
-		if( _childs != null )
-			for( OptNode n : _childs )
-			{
+	public boolean isCPOnly() {
+		boolean ret = (_etype == ExecType.CP);
+		if( !isLeaf() )
+			for( OptNode n : _childs ) {
 				if( !ret ) break; //early abort if already false
 				ret &= n.isCPOnly();
 			}
 		return ret;
 	}
 
-	public int getTotalK()
-	{
+	public int getTotalK() {
 		int k = 1;		
-		if( _childs != null )
+		if( !isLeaf() )
 			for( OptNode n : _childs )
 				k = Math.max(k, n.getTotalK() );
 		
-		if( _ntype == NodeType.PARFOR )
-		{
+		if( _ntype == NodeType.PARFOR ) {
 			if( _etype==ExecType.CP )
 				k = _k * k;
 			else //MR
@@ -404,104 +345,80 @@ public class OptNode
 		return k;
 	}
 
-	public long getMaxC( long N )
-	{
+	public long getMaxC( long N ) {
 		long maxc = N;
-		if( _childs != null )
+		if( !isLeaf() )
 			for( OptNode n : _childs )
-				maxc = Math.min(maxc, n.getMaxC( N ) );
+				maxc = Math.min(maxc, n.getMaxC( N ));
 		
-		if( _ntype == NodeType.HOP )
-		{
-			String ts = getParam( ParamType.TASK_SIZE );
+		if( _ntype == NodeType.HOP ) {
+			String ts = getParam(ParamType.TASK_SIZE);
 			if( ts != null )
-				maxc = Math.min(maxc, Integer.parseInt(ts) );
+				maxc = Math.min(maxc, Integer.parseInt(ts));
 		}
 		
-		if(    _ntype == NodeType.PARFOR 
-		    && _etype == ExecType.CP    )
-		{
+		if( _ntype == NodeType.PARFOR && _etype == ExecType.CP)
 			maxc = maxc / _k; //intdiv
-		}
 		
 		return maxc;
 	}
 
-	public boolean hasNestedParallelism( boolean flagNested )
-	{
+	public boolean hasNestedParallelism( boolean flagNested ) {
 		boolean ret = false;
-		
-		if( _ntype == NodeType.PARFOR )
-		{
+		//check for parfor in nested context
+		if( _ntype == NodeType.PARFOR ) {
 			if( flagNested ) 
 				return true;
 			flagNested = true;
 		}
 		
-		if( _childs != null )
-			for( OptNode n : _childs )
-			{
-				if( ret ) break; //early abort if already true
-				ret |= n.hasNestedParallelism( flagNested );
-			}
-		
-			ret = true;
-			
+		//recursively process children
+		if( !isLeaf() )
+			for( int i=0; i<_childs.size() && !ret; i++ )
+				ret |= _childs.get(i).hasNestedParallelism( flagNested );
 		return ret;
 	}
 
-	public boolean hasNestedPartitionReads( boolean flagNested )
-	{
-		boolean ret = false;
-		if( isLeaf() )
-		{
+	public boolean hasNestedPartitionReads( boolean flagNested ) {
+		if( isLeaf() ) {
 			//partitioned read identified by selected partition format
 			String tmp = getParam(ParamType.DATA_PARTITION_FORMAT);
-			ret = ( tmp !=null 
-					&& PartitionFormat.valueOf(tmp)._dpf!=PDataPartitionFormat.NONE 
-					&& flagNested );
-		}
-		else
-		{
-			for( OptNode n : _childs )
-			{
-				if( n._ntype == NodeType.PARFOR || n._ntype == NodeType.FOR || n._ntype == NodeType.WHILE )
-					flagNested = true;
-				
-				ret |= n.hasNestedPartitionReads( flagNested );
-				if( ret ) break; //early abort if already true
-			}
+			return ( tmp !=null && flagNested
+				&& PartitionFormat.valueOf(tmp)._dpf!=PDataPartitionFormat.NONE);
 		}
 		
+		boolean ret = false;
+		for( int i=0; i<_childs.size() && !ret; i++ ) {
+			OptNode n = _childs.get(i);
+			if( n._ntype.isLoop() )
+				flagNested = true;
+			ret |= n.hasNestedPartitionReads( flagNested );
+		}
 		return ret;
 	}
 
-	public void checkAndCleanupLeafNodes() 
-	{
-		if( _childs != null )
-			for( int i=0; i<_childs.size(); i++ )
-			{
-				OptNode n = _childs.get(i);
-				n.checkAndCleanupLeafNodes();
-				if( n.isLeaf() && n._ntype != NodeType.HOP && n._ntype != NodeType.INST 
-					&& n._ntype != NodeType.FUNCCALL ) // && n._ntype != NodeType.PARFOR
-				{
-					_childs.remove(i);
-					i--;
-				}
+	public void checkAndCleanupLeafNodes() {
+		if( isLeaf() )
+			return;
+		for( int i=0; i<_childs.size(); i++ ) {
+			OptNode n = _childs.get(i);
+			n.checkAndCleanupLeafNodes();
+			if( n.isLeaf() && n._ntype != NodeType.HOP && n._ntype != NodeType.INST 
+				&& n._ntype != NodeType.FUNCCALL ) {
+				_childs.remove(i);
+				i--;
 			}
+		}
 	}
 
-	public void checkAndCleanupRecursiveFunc(Set<String> stack) 
-	{
+	public void checkAndCleanupRecursiveFunc(Set<String> stack) {
 		//recursive invocation
 		if( !isLeaf() )
 			for( OptNode n : _childs )
 				n.checkAndCleanupRecursiveFunc( stack );
 	
 		//collect and update func info
-		if(_ntype == NodeType.FUNCCALL)
-		{
+		if(_ntype == NodeType.FUNCCALL) {
 			String rec = getParam(ParamType.RECURSIVE_CALL);
 			String fname = getParam(ParamType.OPSTRING);
 			if( rec != null && Boolean.parseBoolean(rec) ) 
@@ -573,23 +490,24 @@ public class OptNode
 	}
 
 	/**
-	 * Determines the maximum problem size of all children.
+	 * Determines the maximum problem size, in terms of the maximum
+	 * total number of inner loop iterations, of the entire subtree.
 	 * 
 	 * @return maximum problem size
 	 */
-	public long getMaxProblemSize() 
-	{
-		long max = 0;
-		if( _childs != null )
+	public long getMaxProblemSize() {
+		//recursively process children
+		long max = 1;
+		if( !isLeaf() )
 			for( OptNode n : _childs )
-				max = Math.max(max, n.getMaxProblemSize());		
-		else
-			max = 1;
+				max = Math.max(max, n.getMaxProblemSize());
 		
-		if( _ntype == NodeType.PARFOR )
-			max = max * Long.parseLong(_params.get(ParamType.NUM_ITERATIONS));
-
+		//scale problem size by number of loop iterations
+		if( _ntype.isLoop() && !isLeaf() ) {
+			String numIter = getParam(ParamType.NUM_ITERATIONS);
+			max *= (numIter != null) ? Long.parseLong(numIter) :
+				CostEstimator.FACTOR_NUM_ITERATIONS;
+		}
 		return max;
 	}
-
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/8dbc9302/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java
index d8cad78..1aae331 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptTreeConverter.java
@@ -358,12 +358,11 @@ public class OptTreeConverter
 			
 			//process body
 			int len = ws.getBody().size();
-			for( int i=0; i<wpb.getChildBlocks().size() && i<len; i++ )
-			{
+			for( int i=0; i<wpb.getChildBlocks().size() && i<len; i++ ) {
 				ProgramBlock lpb = wpb.getChildBlocks().get(i);
 				StatementBlock lsb = ws.getBody().get(i);
 				node.addChild( rCreateAbstractOptNode(lsb,lpb,vars,false, memo) );
-			}			
+			}
 		}
 		else if( pb instanceof ForProgramBlock && sb instanceof ForStatementBlock && !(pb instanceof ParForProgramBlock) )
 		{
@@ -392,12 +391,11 @@ public class OptTreeConverter
 			
 			//process body
 			int len = fs.getBody().size();
-			for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ )
-			{
+			for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ ) {
 				ProgramBlock lpb = fpb.getChildBlocks().get(i);
 				StatementBlock lsb = fs.getBody().get(i);
 				node.addChild( rCreateAbstractOptNode(lsb,lpb,vars,false, memo) );
-			}	
+			}
 		}
 		else if( pb instanceof ParForProgramBlock && sb instanceof ParForStatementBlock )
 		{
@@ -409,11 +407,10 @@ public class OptTreeConverter
 			_hlMap.putProgMapping(sb, pb, node);
 			node.setK( fpb.getDegreeOfParallelism() );
 			long N = fpb.getNumIterations();
-			node.addParam(ParamType.NUM_ITERATIONS, (N!=-1) ? String.valueOf(N) : 
-															  String.valueOf(CostEstimator.FACTOR_NUM_ITERATIONS));
-
-			switch(fpb.getExecMode())
-			{
+			node.addParam(ParamType.NUM_ITERATIONS, (N!=-1) ? String.valueOf(N) :
+				String.valueOf(CostEstimator.FACTOR_NUM_ITERATIONS));
+			
+			switch(fpb.getExecMode()) {
 				case LOCAL:
 					node.setExecType(ExecType.CP);
 					break;
@@ -429,8 +426,7 @@ public class OptTreeConverter
 					node.setExecType(null);
 			}
 			
-			if( !topLevel )
-			{
+			if( !topLevel ) {
 				fsb.getFromHops().resetVisitStatus();
 				fsb.getToHops().resetVisitStatus();
 				if( fsb.getIncrementHops()!=null )
@@ -443,8 +439,7 @@ public class OptTreeConverter
 			
 			//process body
 			int len = fs.getBody().size();
-			for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ )
-			{
+			for( int i=0; i<fpb.getChildBlocks().size() && i<len; i++ ) {
 				ProgramBlock lpb = fpb.getChildBlocks().get(i);
 				StatementBlock lsb = fs.getBody().get(i);
 				node.addChild( rCreateAbstractOptNode(lsb,lpb,vars,false, memo) );

http://git-wip-us.apache.org/repos/asf/systemml/blob/8dbc9302/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index 9456703..b8da25a 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -255,7 +255,8 @@ public class OptimizerRuleBased extends Optimizer
 		LOG.debug(getOptMode()+" OPT: estimated new mem (serial exec) M="+toMB(M1) );
 		
 		//determine memory consumption for what-if: all-cp or partitioned 
-		double M2 = _cost.getEstimate(TestMeasure.MEMORY_USAGE, pn, LopProperties.ExecType.CP);
+		double M2 = pn.isCPOnly() ? M1 :
+			_cost.getEstimate(TestMeasure.MEMORY_USAGE, pn, LopProperties.ExecType.CP);
 		LOG.debug(getOptMode()+" OPT: estimated new mem (serial exec, all CP) M="+toMB(M2) );
 		double M3 = _cost.getEstimate(TestMeasure.MEMORY_USAGE, pn, true);
 		LOG.debug(getOptMode()+" OPT: estimated new mem (cond partitioning) M="+toMB(M3) );
@@ -898,17 +899,21 @@ public class OptimizerRuleBased extends Optimizer
 		return requiresRecompile;
 	}
 
-	protected boolean isLargeProblem(OptNode pn, double M0)
+	protected boolean isLargeProblem(OptNode pn, double M)
 	{
-		return ((_N >= PROB_SIZE_THRESHOLD_REMOTE || _Nmax >= 10 * PROB_SIZE_THRESHOLD_REMOTE )
-				&& M0 > PROB_SIZE_THRESHOLD_MB ); //original operations at least larger than 256MB
+		//TODO get a proper time estimate based to capture compute-intensive scenarios
+		
+		//rule-based decision based on number of outer iterations or maximum number of
+		//inner iterations (w/ appropriately scaled minimum data size threshold); 
+		return (_N >= PROB_SIZE_THRESHOLD_REMOTE && M > PROB_SIZE_THRESHOLD_MB)
+			|| (_Nmax >= 10 * PROB_SIZE_THRESHOLD_REMOTE && M > PROB_SIZE_THRESHOLD_MB/10);
 	}
 
 	protected boolean isCPOnlyPossible( OptNode n, double memBudget ) 
 		throws DMLRuntimeException
 	{
 		ExecType et = n.getExecType();
-		boolean ret = ( et == ExecType.CP);		
+		boolean ret = ( et == ExecType.CP);
 		
 		if( n.isLeaf() && et == getRemoteExecType() )
 		{