You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/10/11 21:33:21 UTC

systemml git commit: [SYSTEMML-540] Allow user to generate an inlined DML script in Caffe2DML

Repository: systemml
Updated Branches:
  refs/heads/master 11c67055a -> ef1945d70


[SYSTEMML-540] Allow user to generate an inlined DML script in Caffe2DML

- The inlining code is generic enough to be extended to perform parser-level inlining. This commit allows us to compare the tradeoffs of performing script-level inlining vs. hop-level inlining.
- Refactored DMLParserWrapper and also added javadoc.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ef1945d7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ef1945d7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ef1945d7

Branch: refs/heads/master
Commit: ef1945d70a85df4f646c315d06a1a094dad6ebb2
Parents: 11c6705
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Oct 11 14:28:59 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Oct 11 14:32:35 2018 -0700

----------------------------------------------------------------------
 .../parser/common/CustomErrorListener.java      |   8 +
 .../sysml/parser/dml/DMLParserWrapper.java      | 130 ++-
 .../java/org/apache/sysml/parser/dml/Dml.g4     |   6 +-
 .../sysml/parser/dml/DmlPreprocessor.java       |  13 +-
 .../apache/sysml/parser/dml/InlineHelper.java   | 798 +++++++++++++++++++
 .../sysml/parser/dml/InlineableMethods.java     |  98 +++
 .../controlprogram/caching/CacheableData.java   |   9 +-
 .../gpu/context/GPUMemoryManager.java           |   2 +-
 src/main/python/systemml/mllearn/estimators.py  |   5 +-
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |   8 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    |  12 +-
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  35 +-
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  36 +-
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  58 +-
 14 files changed, 1114 insertions(+), 104 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
index 2af5f69..b82afc9 100644
--- a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
+++ b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
@@ -22,6 +22,7 @@ package org.apache.sysml.parser.common;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Set;
 
 import org.antlr.v4.runtime.BaseErrorListener;
 import org.antlr.v4.runtime.RecognitionException;
@@ -38,6 +39,9 @@ public class CustomErrorListener extends BaseErrorListener {
 	private boolean atLeastOneError = false;
 	private boolean atLeastOneWarning = false;
 	private String currentFileName = null;
+	
+	// Names of user internal and external function definitions
+	public Set<String> functions;
 
 	/**
 	 * List of parse issues.
@@ -55,6 +59,10 @@ public class CustomErrorListener extends BaseErrorListener {
 	public void unsetCurrentFileName() {
 		currentFileName = null;
 	}
+	
+	public Set<String> getFunctionDefs() {
+		return functions;
+	}
 
 	/**
 	 * Validation error occurred. Add the error to the list of parse issues.

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
index 1d7daa1..9b3f8c4 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
@@ -23,12 +23,14 @@ import java.io.ByteArrayInputStream;
 import java.io.FileNotFoundException;
 import java.io.IOException;
 import java.io.InputStream;
+import java.util.HashMap;
 import java.util.Map;
 
 import org.antlr.v4.runtime.ANTLRInputStream;
 import org.antlr.v4.runtime.BailErrorStrategy;
 import org.antlr.v4.runtime.CommonTokenStream;
 import org.antlr.v4.runtime.DefaultErrorStrategy;
+import org.antlr.v4.runtime.TokenStreamRewriter;
 import org.antlr.v4.runtime.atn.PredictionMode;
 import org.antlr.v4.runtime.misc.ParseCancellationException;
 import org.antlr.v4.runtime.tree.ParseTree;
@@ -74,6 +76,15 @@ import org.apache.sysml.parser.dml.DmlParser.StatementContext;
 public class DMLParserWrapper extends ParserWrapper
 {
 	private static final Log LOG = LogFactory.getLog(DMLScript.class.getName());
+	
+	// Rewriter is only used in getInlineableMethods
+	private TokenStreamRewriter rewriter = null;
+	
+	// The below fields are set in the createAST method
+	// Can be null or the path to the DML file
+	private String fileName; 
+	// Can be null or the DML script. Note: fileName and dmlScript must not both be null
+	private String dmlScript;
 
 	/**
 	 * Parses the passed file with command line parameters. You can either pass both (local file) or just dmlScript (hdfs) or just file name (import command)
@@ -88,17 +99,72 @@ public class DMLParserWrapper extends ParserWrapper
 	}
 	
 	/**
-	 * This function is supposed to be called directly only from DmlSyntacticValidator when it encounters 'import'
-	 * @param fileName script file name
-	 * @param dmlScript script file contents
-	 * @param sourceNamespace namespace from source statement
-	 * @param argVals script arguments
-	 * @return dml program, or null if at least one error
+	 * Performs preprocessing using the DmlPreprocessor listener class.
+	 * 
+	 * @param tree parse tree generated by createAST method
+	 * @param errorListener listener that captures potential syntactic errors 
+	 * @return a parse tree walker to perform further validation
 	 */
-	public DMLProgram doParse(String fileName, String dmlScript, String sourceNamespace, Map<String,String> argVals) {
-		DMLProgram dmlPgm = null;
+	ParseTreeWalker preprocess(ParseTree tree, CustomErrorListener errorListener) {
+		ParseTreeWalker walker = new ParseTreeWalker();
+		// Get list of function definitions which take precedence over built-in functions if same name
+		walker.walk(new DmlPreprocessor(errorListener),  tree);
+		return walker;
+	}
+	
+	/**
+	 * Get the inline-able methods
+	 * 
+	 * @param fileName1 can be null or the path to the DML file
+	 * @param dmlScript1 can be null or the DML script. Note: fileName1 and dmlScript1 must not both be null.
+	 * @param sourceNamespace source namespace
+	 * @param argVals command-line arguments
+	 * @return hashmap of inline-able methods
+	 */
+	public HashMap<String, InlineableMethods> getInlineableMethods(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals) {
+		// Create AST and do preprocessing
+		CustomErrorListener errorListener = new CustomErrorListener();
+		ParseTree tree = createAST(fileName1, dmlScript1, sourceNamespace, argVals, errorListener, true);
+		ParseTreeWalker walker = preprocess(tree, errorListener);
+				
+		// Note: this method uses InlineHelper as a listener to perform rewriting of local variables
+		// It does so in two phases:
+		// Phase 1. Rewriting phase where local variables are rewritten by adding a prefix.
+		// Phase 2. Capture the body of the functions using InlineableMethods class
+		
+		// Rewrite all the local variables by adding prefix 
+		InlineHelper validator = new InlineHelper(errorListener, argVals, sourceNamespace, errorListener.getFunctionDefs(), rewriter);
+		validator.setPhase(true);
+		walker.walk(validator, tree);
 		
-		ANTLRInputStream in;
+		// Use the rewritten text as the new DML script and create AST again
+		fileName = null; dmlScript = rewriter.getText();
+		errorListener = new CustomErrorListener();
+		tree = createAST(fileName, dmlScript, sourceNamespace, argVals, errorListener, true);
+		walker = preprocess(tree, errorListener);
+				
+		// Put the content of rewritten function body in the inlineMap
+		validator.setPhase(false);
+		walker.walk(validator, tree);
+		
+		return validator.inlineMap;
+	}
+	
+	/**
+	 * Create an ANTLR parse tree for the input DML script
+	 * 
+	 * @param fileName1 can be null or the path to the DML file
+	 * @param dmlScript1 can be null or the DML script. Note: fileName1 and dmlScript1 must not both be null.
+	 * @param sourceNamespace source namespace
+	 * @param argVals command-line arguments
+	 * @param errorListener listener that captures potential syntactic errors
+	 * @param performRewriting should perform rewriting of tokens
+	 * @return a parse tree
+	 */
+	private ParseTree createAST(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals, CustomErrorListener errorListener, boolean performRewriting) {
+		ANTLRInputStream in = null;
+		this.fileName = fileName1;
+		this.dmlScript = dmlScript1;
 		try {
 			if(dmlScript == null) {
 				dmlScript = readDMLScript(fileName, LOG);
@@ -113,13 +179,13 @@ public class DMLParserWrapper extends ParserWrapper
 		} catch (LanguageException e) {
 			throw new ParseException(e.getMessage(), e);
 		}
-
-		ProgramrootContext ast = null;
-		CustomErrorListener errorListener = new CustomErrorListener();
 		
+		ProgramrootContext ast = null;
 		try {
 			DmlLexer lexer = new DmlLexer(in);
 			CommonTokenStream tokens = new CommonTokenStream(lexer);
+			if(performRewriting)
+				rewriter = new TokenStreamRewriter(tokens);
 			DmlParser antlr4Parser = new DmlParser(tokens);
 			
 			boolean tryOptimizedParsing = false; // For now no optimization, since it is not able to parse integer value. 
@@ -163,19 +229,31 @@ public class DMLParserWrapper extends ParserWrapper
 		catch(Exception e) {
 			throw new ParseException("ERROR: Cannot parse the program:" + fileName, e);
 		}
+		return ast;
+	}
+	
+	
+	
+	/**
+	 * This function is supposed to be called directly only from DmlSyntacticValidator when it encounters 'import'
+	 * 
+	 * @param fileName1 script file name
+	 * @param dmlScript1 script file contents
+	 * @param sourceNamespace namespace from source statement
+	 * @param argVals script arguments
+	 * @return dml program, or null if at least one error
+	 */
+	public DMLProgram doParse(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals) {
+		// Create AST and do preprocessing
+		CustomErrorListener errorListener = new CustomErrorListener();
+		ParseTree tree = createAST(fileName1, dmlScript1, sourceNamespace, argVals, errorListener, false);
+		ParseTreeWalker walker = preprocess(tree, errorListener);
 		
-
-		// Now convert the parse tree into DMLProgram
-		// Do syntactic validation while converting 
-		ParseTree tree = ast;
-		// And also do syntactic validation
-		ParseTreeWalker walker = new ParseTreeWalker();
-		// Get list of function definitions which take precedence over built-in functions if same name
-		DmlPreprocessor prep = new DmlPreprocessor(errorListener);
-		walker.walk(prep,  tree);
-		// Syntactic validation
-		DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, prep.getFunctionDefs());
+		// Perform syntactic validation using DmlSyntacticValidator listener
+		DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, errorListener.getFunctionDefs());
 		walker.walk(validator, tree);
+		
+		// Check for parse issues and warnings
 		errorListener.unsetCurrentFileName();
 		this.parseIssues = errorListener.getParseIssues();
 		this.atLeastOneWarning = errorListener.isAtLeastOneWarning();
@@ -186,9 +264,9 @@ public class DMLParserWrapper extends ParserWrapper
 		if (atLeastOneWarning) {
 			LOG.warn(CustomErrorListener.generateParseIssuesMessage(dmlScript, parseIssues));
 		}
-		dmlPgm = createDMLProgram(ast, sourceNamespace);
 		
-		return dmlPgm;
+		// Create and return the DML program
+		return createDMLProgram((ProgramrootContext)tree, sourceNamespace);
 	}
 	
 	private static DMLProgram createDMLProgram(ProgramrootContext ast, String sourceNamespace) {
@@ -255,4 +333,4 @@ public class DMLParserWrapper extends ParserWrapper
 		
 		return dmlPgm;
 	}
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/Dml.g4
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
index 46ee178..08e1288 100644
--- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4
+++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
@@ -216,6 +216,6 @@ COMMANDLINE_POSITION_ID: '$' DIGIT+;
 STRING: '"' ( ESC | ~[\\"] )*? '"' | '\'' ( ESC | ~[\\'] )*? '\'';
 fragment ESC : '\\' [btnfr"'\\] ;
 // Comments, whitespaces and new line
-LINE_COMMENT : '#' .*? '\r'? '\n' -> skip ;
-MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> skip ;
-WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> skip ;
+LINE_COMMENT : '#' .*? '\r'? '\n' -> channel(HIDDEN) ;
+MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> channel(HIDDEN) ;
+WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> channel(HIDDEN) ;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
index 56eb8ca..49e96e7 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.parser.dml;
 
 import java.util.HashSet;
-import java.util.Set;
 
 import org.antlr.v4.runtime.ParserRuleContext;
 import org.antlr.v4.runtime.Token;
@@ -84,16 +83,10 @@ import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext;
 public class DmlPreprocessor implements DmlListener {
 
 	protected final CustomErrorListener errorListener;
-	// Names of user internal and external functions definitions
-	protected Set<String> functions;
 
 	public DmlPreprocessor(CustomErrorListener errorListener) {
 		this.errorListener = errorListener;
-		functions = new HashSet<>();
-	}
-
-	public Set<String> getFunctionDefs() {
-		return functions;
+		this.errorListener.functions = new HashSet<>();
 	}
 	
 	@Override
@@ -113,8 +106,8 @@ public class DmlPreprocessor implements DmlListener {
 	public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {}
 
 	protected void validateFunctionName(String name, ParserRuleContext ctx) {
-		if (!functions.contains(name)) {
-			functions.add(name);
+		if (!errorListener.functions.contains(name)) {
+			errorListener.functions.add(name);
 		}
 		else {
 			notifyErrorListeners("Function Name Conflict: '" + name + "' already defined in " + errorListener.getCurrentFileName(), ctx.start);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
new file mode 100644
index 0000000..34d886c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
@@ -0,0 +1,798 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.antlr.v4.runtime.Token;
+import org.antlr.v4.runtime.TokenStreamRewriter;
+import org.antlr.v4.runtime.tree.ErrorNode;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.ParameterExpression;
+import org.apache.sysml.parser.common.CommonSyntacticValidator;
+import org.apache.sysml.parser.common.CustomErrorListener;
+import org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanAndExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanNotExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanOrExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BuiltinFunctionExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.CommandlineParamExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.CommandlinePositionExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstDoubleIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstFalseExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstIntIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstStringIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstTrueExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.DataIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ExternalFunctionDefExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ForStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.FunctionCallAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IfStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IfdefAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.ImportStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IndexedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.InternalFunctionDefExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.IterablePredicateColonExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.IterablePredicateSeqExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MatrixDataTypeCheckContext;
+import org.apache.sysml.parser.dml.DmlParser.MatrixMulExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.Ml_typeContext;
+import org.apache.sysml.parser.dml.DmlParser.ModIntDivExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MultDivExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MultiIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ParForStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.ParameterizedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.PathStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.PowerExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ProgramrootContext;
+import org.apache.sysml.parser.dml.DmlParser.RelationalExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.SimpleDataIdentifierExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.StatementContext;
+import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedKeyValueStringContext;
+import org.apache.sysml.parser.dml.DmlParser.TypedArgAssignContext;
+import org.apache.sysml.parser.dml.DmlParser.TypedArgNoAssignContext;
+import org.apache.sysml.parser.dml.DmlParser.UnaryExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ValueTypeContext;
+import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext;
+
+/**
+ * This class is used to generate inline-able methods.
+ * It does so in two phases:
+ * - Phase 1. Rewriting phase where local variables are rewritten by adding a prefix.
+ * - Phase 2. Capture the body of the functions using InlineableMethods class 
+ */
+public class InlineHelper extends CommonSyntacticValidator implements DmlListener {
+	final static String ARG_PREFIX;
+	static {
+		Random rand = new Random();
+		ARG_PREFIX = "INTERNAL_PREFIX_" + Math.abs(rand.nextLong()) + "_" + Math.abs(rand.nextLong()) + "_"; 
+	}
+	public HashMap<String, InlineableMethods> inlineMap = new HashMap<>();
+	TokenStreamRewriter rewriter;
+	
+	// Set internally
+	HashSet<String> variables = new HashSet<>();
+	String currentFunction = null;
+	boolean isRewritePhase;
+	
+	public InlineHelper(CustomErrorListener errorListener, Map<String, String> argVals, String sourceNamespace,
+			Set<String> prepFunctions, TokenStreamRewriter rewriter1) {
+		super(errorListener, argVals, sourceNamespace, prepFunctions);
+		rewriter = rewriter1;
+	}
+	
+	void setPhase(boolean isRewritePhase1) {
+		isRewritePhase = isRewritePhase1;
+	}
+	
+
+	@Override
+	public void enterInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
+		currentFunction = ctx.name.getText();
+		variables.clear();
+	}
+	
+	@Override
+	public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
+		if(!isRewritePhase) {
+			StringBuilder sb = new StringBuilder();
+			for(StatementContext stmt : ctx.body) {
+				sb.append(stmt.getText());
+				sb.append("\n");
+			}
+			ArrayList<String> inputArgs = new ArrayList<>(); 
+			for(TypedArgAssignContext in : ctx.inputParams) {
+				inputArgs.add(ARG_PREFIX + in.paramName.getText());
+			}
+			ArrayList<String> retVariables = new ArrayList<>();
+			for(TypedArgNoAssignContext out : ctx.outputParams) {
+				retVariables.add(ARG_PREFIX + out.paramName.getText());
+			}
+			
+			inlineMap.put(currentFunction, new InlineableMethods(currentFunction, sb.toString(), variables, inputArgs, retVariables));
+		}
+		currentFunction = null;
+		variables.clear();
+	}
+	
+	@Override
+	public void enterIndexedExpression(IndexedExpressionContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertBefore(ctx.name, " " + ARG_PREFIX);
+		}
+	}
+	
+	@Override
+	public void exitIndexedExpression(IndexedExpressionContext ctx) {
+		if(currentFunction != null)
+			variables.add(ctx.name.getText());
+	}
+	
+	@Override
+	public void enterSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertBefore(ctx.start, " " + ARG_PREFIX);
+			rewriter.insertAfter(ctx.stop, " ");
+		}
+	}
+	
+	@Override
+	public void exitSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {
+		if(currentFunction != null)
+			variables.add(ctx.getText());
+	}
+	
+	@Override
+	public void enterForStatement(ForStatementContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX);
+			rewriter.insertAfter(ctx.iterVar, " ");
+		}
+	}
+	
+	@Override
+	public void enterParForStatement(ParForStatementContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX);
+			rewriter.insertAfter(ctx.iterVar, " ");
+		}
+	}
+	
+	@Override
+	public void exitForStatement(ForStatementContext ctx) {
+		if(currentFunction != null)
+			variables.add(ctx.iterVar.getText());
+		if(currentFunction != null && isRewritePhase) {
+			if(ctx.body != null && ctx.body.size() > 0)
+			rewriter.insertBefore(ctx.body.get(0).start, "\n");
+			rewriter.insertAfter(ctx.stop, "\n");
+		}
+	}
+	
+	@Override
+	public void exitParForStatement(ParForStatementContext ctx) {
+		if(currentFunction != null)
+			variables.add(ctx.iterVar.getText());
+		if(currentFunction != null && isRewritePhase) {
+			if(ctx.body != null && ctx.body.size() > 0)
+				rewriter.insertBefore(ctx.body.get(0).start, "\n");
+			rewriter.insertAfter(ctx.stop, "\n");
+		}
+	}
+	
+
+	@Override
+	protected ConvertedDMLSyntax convertToDMLSyntax(ParserRuleContext ctx, String namespace, String functionName,
+			ArrayList<ParameterExpression> paramExpression, Token fnName) {
+		
+		return null;
+	}
+
+	@Override
+	public void enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterAddSubExpression(AddSubExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterAssignmentStatement(AssignmentStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterAtomicExpression(AtomicExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterBooleanNotExpression(BooleanNotExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterCommandlineParamExpression(CommandlineParamExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterConstFalseExpression(ConstFalseExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterConstIntIdExpression(ConstIntIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterConstStringIdExpression(ConstStringIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterConstTrueExpression(ConstTrueExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterDataIdExpression(DataIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterEveryRule(ParserRuleContext arg0) {
+		
+		
+	}
+
+	@Override
+	public void enterExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {
+		
+		
+	}
+	
+
+	@Override
+	public void enterFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterIfStatement(IfStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterImportStatement(ImportStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterMatrixMulExpression(MatrixMulExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterMl_type(Ml_typeContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterModIntDivExpression(ModIntDivExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterMultDivExpression(MultDivExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterMultiIdExpression(MultiIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterParameterizedExpression(ParameterizedExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterPathStatement(PathStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterPowerExpression(PowerExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterProgramroot(ProgramrootContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterRelationalExpression(RelationalExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterTypedArgAssign(TypedArgAssignContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterTypedArgNoAssign(TypedArgNoAssignContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterUnaryExpression(UnaryExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterValueType(ValueTypeContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void enterWhileStatement(WhileStatementContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertAfter(ctx.stop, ";\n");
+		}
+		
+	}
+
+	@Override
+	public void exitAddSubExpression(AddSubExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitAssignmentStatement(AssignmentStatementContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertAfter(ctx.stop, ";\n");
+		}
+		
+	}
+
+	@Override
+	public void exitAtomicExpression(AtomicExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitBooleanAndExpression(BooleanAndExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitBooleanNotExpression(BooleanNotExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitBooleanOrExpression(BooleanOrExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitCommandlineParamExpression(CommandlineParamExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitConstFalseExpression(ConstFalseExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitConstIntIdExpression(ConstIntIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitConstStringIdExpression(ConstStringIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitConstTrueExpression(ConstTrueExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitDataIdExpression(DataIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitEveryRule(ParserRuleContext arg0) {
+		
+		
+	}
+
+	@Override
+	public void exitExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {
+		
+		
+	}
+
+	
+
+	@Override
+	public void exitFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {
+		if(currentFunction != null && isRewritePhase) {
+			rewriter.insertAfter(ctx.stop, ";\n");
+		}
+		
+	}
+
+	@Override
+	public void exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {
+		// Rewrite only while a current function is set and the rewrite phase is active.
+		if(currentFunction == null || !isRewritePhase)
+			return;
+		rewriter.insertAfter(ctx.stop, ";\n");
+	}
+
+	@Override
+	public void exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {
+		// Rewrite only while a current function is set and the rewrite phase is active.
+		if(currentFunction == null || !isRewritePhase)
+			return;
+		rewriter.insertAfter(ctx.stop, ";\n");
+	}
+
+	@Override
+	public void exitIfStatement(IfStatementContext ctx) {
+		if(currentFunction == null || !isRewritePhase)
+			return;
+		// A newline goes before the first statement of each non-empty branch and after the whole if-statement.
+		if(ctx.ifBody != null && !ctx.ifBody.isEmpty())
+			rewriter.insertBefore(ctx.ifBody.get(0).start, "\n");
+		if(ctx.elseBody != null && !ctx.elseBody.isEmpty())
+			rewriter.insertBefore(ctx.elseBody.get(0).start, "\n");
+		rewriter.insertAfter(ctx.stop, "\n");
+	}
+
+	@Override
+	public void exitImportStatement(ImportStatementContext ctx) {
+		// Rewrite only while a current function is set and the rewrite phase is active.
+		if(currentFunction == null || !isRewritePhase)
+			return;
+		rewriter.insertAfter(ctx.stop, ";\n");
+	}
+
+	@Override
+	public void exitIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitMatrixMulExpression(MatrixMulExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitMl_type(Ml_typeContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitModIntDivExpression(ModIntDivExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitMultDivExpression(MultDivExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitMultiIdExpression(MultiIdExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitParameterizedExpression(ParameterizedExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitPathStatement(PathStatementContext ctx) {
+		// Rewrite only while a current function is set and the rewrite phase is active.
+		if(currentFunction == null || !isRewritePhase)
+			return;
+		rewriter.insertAfter(ctx.stop, ";\n");
+	}
+
+	@Override
+	public void exitPowerExpression(PowerExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitProgramroot(ProgramrootContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitRelationalExpression(RelationalExpressionContext ctx) {
+		
+		
+	}
+	
+	@Override
+	public void exitStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitTypedArgAssign(TypedArgAssignContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitTypedArgNoAssign(TypedArgNoAssignContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitUnaryExpression(UnaryExpressionContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitValueType(ValueTypeContext ctx) {
+		
+		
+	}
+
+	@Override
+	public void exitWhileStatement(WhileStatementContext ctx) {
+		if(currentFunction == null || !isRewritePhase)
+			return; // no current function or not in the rewrite phase: leave the token stream untouched
+		if(ctx.body != null && !ctx.body.isEmpty())
+			rewriter.insertBefore(ctx.body.get(0).start, "\n");
+		rewriter.insertAfter(ctx.stop, "\n");
+	}
+
+	@Override
+	public String falseStringLiteral() {
+		
+		return null;
+	}
+
+	@Override
+	protected Expression handleLanguageSpecificFunction(ParserRuleContext ctx, String functionName,
+			ArrayList<ParameterExpression> paramExpressions) {
+		
+		return null;
+	}
+
+	@Override
+	public String namespaceResolutionOp() {
+		
+		return null;
+	}
+
+	@Override
+	public String trueStringLiteral() {
+		
+		return null;
+	}
+
+	@Override
+	public void visitErrorNode(ErrorNode arg0) {
+		
+		
+	}
+
+	@Override
+	public void visitTerminal(TerminalNode arg0) {
+		
+		
+	}
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
new file mode 100644
index 0000000..d3b0d11
--- /dev/null
+++ b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+
+/** 
+ * This class is a simple container class used to hold the function to be inlined.
+ * It contains the function name, body and also the input and return arguments.
+ * The user invokes getInlinedDML method to get the inlined DML code.
+ */
+public class InlineableMethods {
+	ArrayList<String> _variables;
+	final String _body;
+	final String _fnName;
+	final ArrayList<String> _inputArgs;
+	final ArrayList<String> _retVariables;
+	static int CALLER_ID = 1;
+	/** Stores the function; its variables are kept longest-first (presumably so a name is never rewritten as part of a longer one). */
+	public InlineableMethods(String fnName, String body, HashSet<String> variables, ArrayList<String> inputArgs, ArrayList<String> retVariables) {
+		_fnName = fnName;
+		_body = body;
+		_variables = new ArrayList<String>(variables);
+		_variables.sort(Comparator.comparing(String::length).reversed());
+		_inputArgs = inputArgs;
+		_retVariables = retVariables;
+	}
+	/** Returns the variable names of the function, ordered from longest to shortest. */
+	public ArrayList<String> getLocalVariables() {
+		return _variables;
+	}
+	/** Rewrites the body for one call: variables found in actualArguments become the mapped text, all others get a name unique to this call. */
+	private String _getInlinedDML(HashMap<String, String> actualArguments) {
+		String ret = _body;
+		int callerID = CALLER_ID++;
+		for(String var : _variables) {
+			String originalVarName = var.substring(InlineHelper.ARG_PREFIX.length());
+			if(actualArguments.containsKey(var)) {
+				ret = ret.replace(var, actualArguments.get(var));
+			}
+			else {
+				// internal variable; literal replace is used throughout as neither side is a regular expression
+				ret = ret.replace(var, LOCAL_ARG_PREFIX + _fnName + "_" + callerID + "_" + originalVarName);
+			}
+		}
+		return ret;
+	}
+	/** Returns the body with the formal inputs/returns replaced by the given actuals; throws DMLRuntimeException on a count mismatch. */
+	public String getInlinedDML(ArrayList<String> actualInputArgs, ArrayList<String> actualRetVariables) {
+		HashMap<String, String> actualArguments = new HashMap<>();
+		if(actualInputArgs.size() != _inputArgs.size()) {
+			throw new DMLRuntimeException("Incorrect number of input arguments for the function " + _fnName + ": expected " 
+			+ _inputArgs.size() + " (" + String.join(", ", _inputArgs) + ") but found " + actualInputArgs.size() 
+			+ " (" + String.join(", ", actualInputArgs) + ")");
+		}
+		if(actualRetVariables.size() != _retVariables.size()) {
+			throw new DMLRuntimeException("Incorrect number of return variables for the function " + _fnName + ": expected " 
+			+ _retVariables.size() + " (" + String.join(", ", _retVariables) + ") but found " + actualRetVariables.size()
+			+ " (" + String.join(", ", actualRetVariables) + ")");
+		}
+		for(int i = 0; i < _inputArgs.size(); i++) {
+			actualArguments.put(_inputArgs.get(i), actualInputArgs.get(i));
+		}
+		for(int i = 0; i < _retVariables.size(); i++) {
+			actualArguments.put(_retVariables.get(i), actualRetVariables.get(i));
+		}
+		return _getInlinedDML(actualArguments);
+	}
+	/** Prefix of the per-call local variable names; randomized once per class load (presumably to avoid clashes with user variables). */
+	static final String LOCAL_ARG_PREFIX;
+	static {
+		Random rand = new Random();
+		LOCAL_ARG_PREFIX = "LOCAL_" + (rand.nextLong() & Long.MAX_VALUE) + "_" + (rand.nextLong() & Long.MAX_VALUE);
+		// Masking the sign bit keeps both numbers non-negative; Math.abs(Long.MIN_VALUE) would remain negative.
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index 03bc3b3..15dd23e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -307,7 +307,14 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 	}
 	
 	public MatrixCharacteristics getMatrixCharacteristics() {
-		return _metaData.getMatrixCharacteristics();
+		MatrixCharacteristics mc = _metaData.getMatrixCharacteristics();
+		if(mc.getRowsPerBlock() == -1) {
+			mc.setRowsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+		}
+		if(mc.getColsPerBlock() == -1) {
+			mc.setColsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+		}
+		return mc;
 	}
 
 	public abstract void refreshMetaData();

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 6772b4a..a08d4fd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -298,7 +298,7 @@ public class GPUMemoryManager {
 				evictOrClear(sizeBasedUnlockedGPUObjects.get(), opcode);
 				A = cudaMallocNoWarn(tmpA, size, null);
 				if(A == null)
-					LOG.warn("cudaMalloc failed after clearing/evicting based on size.");
+					LOG.debug("cudaMalloc failed after clearing/evicting based on size.");
 				if(ConfigurationManager.isStatistics()) {
 					long totalTime = System.nanoTime() - t0;
 					GPUStatistics.cudaEvictTime.add(totalTime);

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index fbcd3e2..8a100b4 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -924,7 +924,7 @@ class Caffe2DML(BaseSystemMLClassifier):
             self.estimator.setWeightsToIgnore(ignore_weights)
 
     def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None,
-            output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None):
+            output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None, inline_nn_library=None):
         """
         Set input to Caffe2DML
 
@@ -937,9 +937,12 @@ class Caffe2DML(BaseSystemMLClassifier):
         output_activations: (developer flag) directory to output activations of each layer as csv while prediction. To be used only in batch mode (default: None)
         perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: False)
         parfor_parameters: dictionary for parfor parameters when using allreduce-style algorithms (default: "")
+        inline_nn_library: whether to inline the NN library when generating DML using Caffe2DML (default: False)
         """
         if debug is not None:
             self.estimator.setInput("$debug", str(debug).upper())
+        if inline_nn_library is not None:
+            self.estimator.setInput("$inline_nn_library", str(inline_nn_library).upper())
         if train_algo is not None:
             self.estimator.setInput("$train_algo", str(train_algo).lower())
         if test_algo is not None:

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 26e554f..8ddb1fe 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -119,8 +119,9 @@ object Caffe2DML {
   val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
   val USE_PLUS_EQ = true
-  def layerDir = "nn/layers/"
-  def optimDir = "nn/optim/"
+  def nnDir = "nn/"
+  def layerDir = nnDir + "layers/"
+  def optimDir = nnDir + "optim/"
 
   // Naming conventions:
   val X    = "X"; val y        = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
@@ -159,6 +160,7 @@ object Caffe2DML {
   val BATCH_ALGORITHM = "batch"
   val ALLREDUCE_ALGORITHM = "allreduce"
   val ALLREDUCE_PARALLEL_BATCHES_ALGORITHM = "allreduce_parallel_batches"
+  var INLINE_NN_LIBRARY = false
 }
 
 class Caffe2DML(val sc: SparkContext,
@@ -312,6 +314,7 @@ class Caffe2DML(val sc: SparkContext,
 
     // Flags passed by user
     val DEBUG_TRAINING = if (inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+    Caffe2DML.INLINE_NN_LIBRARY = if (inputs.containsKey("$inline_nn_library")) inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
     setDebugFlags(DEBUG_TRAINING)
 
@@ -721,6 +724,7 @@ class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: C
     reset // Reset the state of DML generator for training script.
 
     val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    Caffe2DML.INLINE_NN_LIBRARY = if (estimator.inputs.containsKey("$inline_nn_library")) estimator.inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
     assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
     estimator.setDebugFlags(DEBUG_PREDICTION)
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index d664f6e..b290983 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -29,6 +29,7 @@ import caffe.Caffe.EltwiseParameter.EltwiseOp
 import org.apache.sysml.runtime.DMLRuntimeException;
 import java.util.ArrayList
 import caffe.Caffe.PoolingParameter.PoolMethod
+import scala.collection.JavaConverters._
 
 trait CaffeLayer extends BaseDMLGenerator {
   // -------------------------------------------------
@@ -125,7 +126,7 @@ trait CaffeLayer extends BaseDMLGenerator {
   // The layers that have a corresponding dml script call this method.
   // Assumption: the first variable of resultVariables is always dX
   def invokeBackward(dmlScript: StringBuilder, outSuffix: String, resultVariables: List[String], arguments: String*): Unit = {
-    invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
     val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
     dmlScript.append("; ")
     bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
@@ -140,6 +141,13 @@ trait CaffeLayer extends BaseDMLGenerator {
     dmlScript.append("\n")
   }
   // --------------------------------------------------------------------------------------
+  
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, arguments, true)
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
+    Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
 }
 
 trait IsLossLayer extends CaffeLayer {
@@ -1603,4 +1611,4 @@ class DeConvolution(val param: LayerParameter, val id: Int, val net: CaffeNetwor
     if (convParam.hasPadW) convParam.getPadW.toString
     else if (convParam.getPadCount > 0) convParam.getPad(0).toString
     else "0"
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 8559c60..da963b9 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -67,10 +67,12 @@ trait CaffeSolver {
       val hasDecayMult = layer.param.getParamList != null && layer.param.getParamList.size >= 1 && layer.param.getParamList.get(0).hasDecayMult
       val newLambda = if(hasDecayMult) layer.param.getParamList.get(0).getDecayMult * lambda else lambda
       
-      dmlScript.append("\t").append(layer.dWeight + "_reg = " + regularizationSource + "::backward(" + layer.weight + ", " + newLambda + ")\n")
+      Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", List[String](layer.dWeight + "_reg"), "backward", 
+          List[String](layer.weight, "" + newLambda), true)
       dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " + " + layer.dWeight + "_reg\n")
       if(layer.shouldUpdateExtraWeight) {
-        dmlScript.append("\t").append(layer.dExtraWeight + "_reg = " + regularizationSource + "::backward(" + layer.extraWeight + ", " + newLambda + ")\n")
+        Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", List[String](layer.dExtraWeight + "_reg"), "backward", 
+          List[String](layer.extraWeight, "" + newLambda), true)
         dmlScript.append("\t").append(layer.dExtraWeight + " = " + layer.dExtraWeight + " + " + layer.dExtraWeight + "_reg\n")
       }
     }
@@ -339,32 +341,19 @@ class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, momentu
    *      input v.
    */
   def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
-    val fn            = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
-    val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
+    
     if (!Caffe2DML.USE_NESTEROV_UDF) {
       regularization_update(regularizationType, lambda, dmlScript, layer)
     }
     if (layer.shouldUpdateWeight)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + ")\n"
-        )
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", List[String](layer.weight, layer.weight + "_v"), "update", 
+          List[String](layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v"), true)
     if (layer.shouldUpdateExtraWeight)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.extraWeight, layer.extraWeight + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.extraWeight, layer.dExtraWeight, getWeightLr(layer), momentum.toString, layer.extraWeight + "_v") + lastParameter + ")\n"
-        )
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", List[String](layer.extraWeight, layer.extraWeight + "_v"), "update", 
+          List[String](layer.extraWeight, layer.dExtraWeight, getWeightLr(layer), momentum.toString, layer.extraWeight + "_v"), true)
     if (layer.shouldUpdateBias)
-      dmlScript
-        .append("\t")
-        .append(
-          "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
-          "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
-        )
+      Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", List[String](layer.bias, layer.bias + "_v"), "update", 
+          List[String](layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v"), true)
   }
   def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
     if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_nesterov::init(" + layer.weight + ")\n")
@@ -372,4 +361,4 @@ class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, momentu
     if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_nesterov::init(" + layer.bias + ")\n")
   }
   def sourceFileName: String = "sgd_nesterov"
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 60396f1..59c75ad 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -75,39 +75,6 @@ trait BaseDMLGenerator {
     sum(dmlScript, rhsVars)
     dmlScript.append("\n")
   }
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
-    if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
-    if (returnVariables.length > 1) dmlScript.append("[")
-    dmlScript.append(returnVariables(0))
-    if (returnVariables.length > 1) {
-      for (i <- 1 until returnVariables.length) {
-        dmlScript.append(",").append(returnVariables(i))
-      }
-      dmlScript.append("]")
-    }
-    dmlScript.append(" = ")
-    dmlScript.append(namespace1)
-    dmlScript.append(functionName)
-    dmlScript.append("(")
-    if (arguments != null) {
-      if (arguments.length != 0)
-        dmlScript.append(arguments(0))
-      if (arguments.length > 1) {
-        for (i <- 1 until arguments.length) {
-          dmlScript.append(",").append(arguments(i))
-        }
-      }
-    }
-    dmlScript.append(")")
-    if (appendNewLine)
-      dmlScript.append("\n")
-  }
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
-  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
   def rightIndexing(dmlScript: StringBuilder, lhsVar:String, rhsVar: String, rl: String, ru: String, cl: String=null, cu: String=null): StringBuilder = {
     dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("[")
     if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
@@ -279,6 +246,7 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
     // Append source statements for layers as well as solver
     source(net, solver, if (isTraining) Array[String]("l1_reg") else null)
     source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
+    source(dmlScript, numTabs, "util", Caffe2DML.nnDir)
 
     if (isTraining) {
       // Append external built-in function headers:
@@ -346,4 +314,4 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
 
   def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit =
     net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 63aaf91..da53f65 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -39,6 +39,11 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.api.mlcontext.MLContext
 import org.apache.spark.SparkContext
 import org.apache.spark.api.java.JavaSparkContext
+import org.apache.sysml.parser.ParserWrapper
+import org.apache.sysml.parser.dml.DMLParserWrapper
+import org.apache.sysml.parser.dml.InlineableMethods
+import java.util.ArrayList
+import scala.collection.JavaConverters._
 
 object Utils {
   // ---------------------------------------------------------------------------------------------
@@ -64,6 +69,57 @@ object Utils {
       line = bufReader.readLine()
     }
   }
+  
+  def readDMLScript(fileName:String):String = ParserWrapper.readDMLScript(fileName, Caffe2DML.LOG)
+  val inlineableMethods = new java.util.HashMap[String, java.util.HashMap[String, InlineableMethods]]()
+  def getInlineableMethod(sourceFilePath:String, namespace:String, fnName:String):InlineableMethods = {
+    // Each namespace is parsed at most once; later lookups are served from the inlineableMethods cache.
+    if(!inlineableMethods.containsKey(namespace))
+      inlineableMethods.put(namespace, new DMLParserWrapper().getInlineableMethods(sourceFilePath, null, namespace, null))
+    val method = inlineableMethods.get(namespace).get(fnName)
+    if(method == null) // fail with a clear message instead of handing null to the caller
+      throw new DMLRuntimeException("Cannot inline the function " + fnName + " as it was not found in " + sourceFilePath)
+    return method
+  }
+  
+  def invoke(dir:String, dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
+    if(Caffe2DML.INLINE_NN_LIBRARY) {
+      // Append the callee's body (looked up in dir + <namespace>.dml) instead of a call; appendNewLine is not consulted on this path.
+      // For now, do not inline recursively
+      val sourceFileName = if(namespace1.endsWith("::")) namespace1.substring(0, namespace1.length() - 2) else namespace1
+      val method = getInlineableMethod(dir + sourceFileName + ".dml", namespace1, functionName)
+      val generatedDML = method.getInlinedDML(new ArrayList[String](arguments.asJava), new ArrayList[String](returnVariables.asJava))
+      dmlScript.append(generatedDML)
+      dmlScript.append("\n")
+      //System.out.println(generatedDML)
+      return
+    }
+    if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
+    if (returnVariables.length > 1) dmlScript.append("[")
+    dmlScript.append(returnVariables(0))
+    if (returnVariables.length > 1) {
+      for (i <- 1 until returnVariables.length) {
+        dmlScript.append(",").append(returnVariables(i))
+      }
+      dmlScript.append("]")
+    }
+    dmlScript.append(" = ")
+    dmlScript.append(namespace1)
+    dmlScript.append(functionName)
+    dmlScript.append("(")
+    if (arguments != null) {
+      if (arguments.length != 0)
+        dmlScript.append(arguments(0))
+      if (arguments.length > 1) {
+        for (i <- 1 until arguments.length) {
+          dmlScript.append(",").append(arguments(i))
+        }
+      }
+    }
+    dmlScript.append(")")
+    if (appendNewLine)
+      dmlScript.append("\n")
+  }
 
   // ---------------------------------------------------------------------------------------------
   def parseSolver(solverFilePath: String): CaffeSolver = parseSolver(readCaffeSolver(solverFilePath))
@@ -324,4 +380,4 @@ class Utils {
   def saveCaffeModelFile(sc: JavaSparkContext, deployFilePath: String, caffeModelFilePath: String, outputDirectory: String, format: String): Unit =
     Utils.saveCaffeModelFile(sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
 
-}
+}
\ No newline at end of file