You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/10/11 21:33:21 UTC
systemml git commit: [SYSTEMML-540] Allow user to generate an inlined
DML script in Caffe2DML
Repository: systemml
Updated Branches:
refs/heads/master 11c67055a -> ef1945d70
[SYSTEMML-540] Allow user to generate an inlined DML script in Caffe2DML
- The inlining code is generic enough to be extended to parser-level inlining. This commit lets us compare the trade-offs of script-level inlining vs. hop-level inlining.
- Refactored DMLParserWrapper and added Javadoc.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ef1945d7
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ef1945d7
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ef1945d7
Branch: refs/heads/master
Commit: ef1945d70a85df4f646c315d06a1a094dad6ebb2
Parents: 11c6705
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Oct 11 14:28:59 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Oct 11 14:32:35 2018 -0700
----------------------------------------------------------------------
.../parser/common/CustomErrorListener.java | 8 +
.../sysml/parser/dml/DMLParserWrapper.java | 130 ++-
.../java/org/apache/sysml/parser/dml/Dml.g4 | 6 +-
.../sysml/parser/dml/DmlPreprocessor.java | 13 +-
.../apache/sysml/parser/dml/InlineHelper.java | 798 +++++++++++++++++++
.../sysml/parser/dml/InlineableMethods.java | 98 +++
.../controlprogram/caching/CacheableData.java | 9 +-
.../gpu/context/GPUMemoryManager.java | 2 +-
src/main/python/systemml/mllearn/estimators.py | 5 +-
.../org/apache/sysml/api/dl/Caffe2DML.scala | 8 +-
.../org/apache/sysml/api/dl/CaffeLayer.scala | 12 +-
.../org/apache/sysml/api/dl/CaffeSolver.scala | 35 +-
.../org/apache/sysml/api/dl/DMLGenerator.scala | 36 +-
.../scala/org/apache/sysml/api/dl/Utils.scala | 58 +-
14 files changed, 1114 insertions(+), 104 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
index 2af5f69..b82afc9 100644
--- a/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
+++ b/src/main/java/org/apache/sysml/parser/common/CustomErrorListener.java
@@ -22,6 +22,7 @@ package org.apache.sysml.parser.common;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Set;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.RecognitionException;
@@ -38,6 +39,9 @@ public class CustomErrorListener extends BaseErrorListener {
private boolean atLeastOneError = false;
private boolean atLeastOneWarning = false;
private String currentFileName = null;
+
+ // Names of the user-defined (internal and external) functions, filled in by DmlPreprocessor.
+ // Public because DmlPreprocessor (in another package) assigns and updates it directly; read it via getFunctionDefs().
+ public Set<String> functions;
/**
* List of parse issues.
@@ -55,6 +59,10 @@ public class CustomErrorListener extends BaseErrorListener {
public void unsetCurrentFileName() {
currentFileName = null;
}
+
+ /**
+  * Returns the names of the user-defined functions collected during preprocessing.
+  * The field has no initializer here; it is assigned when a DmlPreprocessor is created with this listener.
+  *
+  * @return the live function-name set (not a copy)
+  */
+ public Set<String> getFunctionDefs() {
+ 	return this.functions;
+ }
/**
* Validation error occurred. Add the error to the list of parse issues.
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
index 1d7daa1..9b3f8c4 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DMLParserWrapper.java
@@ -23,12 +23,14 @@ import java.io.ByteArrayInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
+import java.util.HashMap;
import java.util.Map;
import org.antlr.v4.runtime.ANTLRInputStream;
import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.DefaultErrorStrategy;
+import org.antlr.v4.runtime.TokenStreamRewriter;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.antlr.v4.runtime.tree.ParseTree;
@@ -74,6 +76,15 @@ import org.apache.sysml.parser.dml.DmlParser.StatementContext;
public class DMLParserWrapper extends ParserWrapper
{
private static final Log LOG = LogFactory.getLog(DMLScript.class.getName());
+
+ // Token rewriter, (re)created by createAST only when performRewriting is true.
+ // It is only used by getInlineableMethods to obtain the rewritten script text.
+ private TokenStreamRewriter rewriter = null;
+
+ // The fields below are (re)assigned at the start of every createAST call.
+ // Path to the DML file, or null when the script text is supplied directly.
+ private String fileName;
+ // DML script text, or null when it has to be read from fileName.
+ // At least one of fileName and dmlScript must be non-null.
+ private String dmlScript;
/**
* Parses the passed file with command line parameters. You can either pass both (local file) or just dmlScript (hdfs) or just file name (import command)
@@ -88,17 +99,72 @@ public class DMLParserWrapper extends ParserWrapper
}
/**
- * This function is supposed to be called directly only from DmlSyntacticValidator when it encounters 'import'
- * @param fileName script file name
- * @param dmlScript script file contents
- * @param sourceNamespace namespace from source statement
- * @param argVals script arguments
- * @return dml program, or null if at least one error
+ * Performs preprocess using DmlPreprocessor listener class.
+ *
+ * @param tree parse tree generated by createAST method
+ * @param errorListener listener that captures potential syntactic errors
+ * @return a parse tree walker to perform further validation
*/
- public DMLProgram doParse(String fileName, String dmlScript, String sourceNamespace, Map<String,String> argVals) {
- DMLProgram dmlPgm = null;
+ ParseTreeWalker preprocess(ParseTree tree, CustomErrorListener errorListener) {
+ 	// Collect the user-defined function names first (they are stored on errorListener);
+ 	// they take precedence over built-in functions with the same name.
+ 	DmlPreprocessor functionCollector = new DmlPreprocessor(errorListener);
+ 	ParseTreeWalker treeWalker = new ParseTreeWalker();
+ 	treeWalker.walk(functionCollector, tree);
+ 	return treeWalker;
+ }
+
+ /**
+  * Get the inline-able methods of a DML script.
+  *
+  * The script is parsed twice. The first pass (rewrite phase) prefixes every local variable of
+  * each function using a token rewriter. The second pass re-parses the rewritten text and
+  * captures the function bodies in InlineableMethods objects.
+  *
+  * @param fileName1 can be null or the path to the DML file
+  * @param dmlScript1 can be null or the DML script. Note: at least one of fileName1 and dmlScript1 must be non-null.
+  * @param sourceNamespace source namespace
+  * @param argVals command-line arguments
+  * @return map from function name to its inline-able method
+  */
+ public HashMap<String, InlineableMethods> getInlineableMethods(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals) {
+ 	// Phase 1: parse the original script with a token rewriter and prefix all local variables
+ 	CustomErrorListener errorListener = new CustomErrorListener();
+ 	ParseTree tree = createAST(fileName1, dmlScript1, sourceNamespace, argVals, errorListener, true);
+ 	ParseTreeWalker walker = preprocess(tree, errorListener);
+ 	InlineHelper validator = new InlineHelper(errorListener, argVals, sourceNamespace, errorListener.getFunctionDefs(), rewriter);
+ 	validator.setPhase(true);
+ 	walker.walk(validator, tree);
- ANTLRInputStream in;
+ 	String rewrittenScript = rewriter.getText();
+
+ 	// Phase 2: re-parse the rewritten text and capture the function bodies in validator.inlineMap.
+ 	// No token rewriting happens in this phase, so no new rewriter is created.
+ 	errorListener = new CustomErrorListener();
+ 	tree = createAST(null, rewrittenScript, sourceNamespace, argVals, errorListener, false);
+ 	walker = preprocess(tree, errorListener);
+ 	validator.setPhase(false);
+ 	walker.walk(validator, tree);
+
+ 	return validator.inlineMap;
+ }
+
+ /**
+ * Create an ANTLR parse tree for the input DML script
+ *
+ * @param fileName1 can be null or the path to the DML file
+ * @param dmlScript1 can be null or the DML script. Note, both fileName and DML script should not be null.
+ * @param sourceNamespace source namespace
+ * @param argVals command-line arguments
+ * @param errorListener listener that captures potential syntactic errors
+ * @param performRewriting should perform rewriting of tokens
+ * @return a parse tree
+ */
+ private ParseTree createAST(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals, CustomErrorListener errorListener, boolean performRewriting) {
+ ANTLRInputStream in = null;
+ this.fileName = fileName1;
+ this.dmlScript = dmlScript1;
try {
if(dmlScript == null) {
dmlScript = readDMLScript(fileName, LOG);
@@ -113,13 +179,13 @@ public class DMLParserWrapper extends ParserWrapper
} catch (LanguageException e) {
throw new ParseException(e.getMessage(), e);
}
-
- ProgramrootContext ast = null;
- CustomErrorListener errorListener = new CustomErrorListener();
+ ProgramrootContext ast = null;
try {
DmlLexer lexer = new DmlLexer(in);
CommonTokenStream tokens = new CommonTokenStream(lexer);
+ if(performRewriting)
+ rewriter = new TokenStreamRewriter(tokens);
DmlParser antlr4Parser = new DmlParser(tokens);
boolean tryOptimizedParsing = false; // For now no optimization, since it is not able to parse integer value.
@@ -163,19 +229,31 @@ public class DMLParserWrapper extends ParserWrapper
catch(Exception e) {
throw new ParseException("ERROR: Cannot parse the program:" + fileName, e);
}
+ return ast;
+ }
+
+
+
+ /**
+ * This function is supposed to be called directly only from DmlSyntacticValidator when it encounters 'import'
+ *
+ * @param fileName1 script file name
+ * @param dmlScript1 script file contents
+ * @param sourceNamespace namespace from source statement
+ * @param argVals script arguments
+ * @return dml program, or null if at least one error
+ */
+ public DMLProgram doParse(String fileName1, String dmlScript1, String sourceNamespace, Map<String,String> argVals) {
+ // Create AST and do preprocessing
+ CustomErrorListener errorListener = new CustomErrorListener();
+ ParseTree tree = createAST(fileName1, dmlScript1, sourceNamespace, argVals, errorListener, false);
+ ParseTreeWalker walker = preprocess(tree, errorListener);
-
- // Now convert the parse tree into DMLProgram
- // Do syntactic validation while converting
- ParseTree tree = ast;
- // And also do syntactic validation
- ParseTreeWalker walker = new ParseTreeWalker();
- // Get list of function definitions which take precedence over built-in functions if same name
- DmlPreprocessor prep = new DmlPreprocessor(errorListener);
- walker.walk(prep, tree);
- // Syntactic validation
- DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, prep.getFunctionDefs());
+ // Perform syntactic validation using DmlSyntacticValidator listener
+ DmlSyntacticValidator validator = new DmlSyntacticValidator(errorListener, argVals, sourceNamespace, errorListener.getFunctionDefs());
walker.walk(validator, tree);
+
+ // Check for parse issues and warning
errorListener.unsetCurrentFileName();
this.parseIssues = errorListener.getParseIssues();
this.atLeastOneWarning = errorListener.isAtLeastOneWarning();
@@ -186,9 +264,9 @@ public class DMLParserWrapper extends ParserWrapper
if (atLeastOneWarning) {
LOG.warn(CustomErrorListener.generateParseIssuesMessage(dmlScript, parseIssues));
}
- dmlPgm = createDMLProgram(ast, sourceNamespace);
- return dmlPgm;
+ // Create and return the DML program
+ return createDMLProgram((ProgramrootContext)tree, sourceNamespace);
}
private static DMLProgram createDMLProgram(ProgramrootContext ast, String sourceNamespace) {
@@ -255,4 +333,4 @@ public class DMLParserWrapper extends ParserWrapper
return dmlPgm;
}
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/Dml.g4
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/Dml.g4 b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
index 46ee178..08e1288 100644
--- a/src/main/java/org/apache/sysml/parser/dml/Dml.g4
+++ b/src/main/java/org/apache/sysml/parser/dml/Dml.g4
@@ -216,6 +216,6 @@ COMMANDLINE_POSITION_ID: '$' DIGIT+;
STRING: '"' ( ESC | ~[\\"] )*? '"' | '\'' ( ESC | ~[\\'] )*? '\'';
fragment ESC : '\\' [btnfr"'\\] ;
// Comments, whitespaces and new line
-LINE_COMMENT : '#' .*? '\r'? '\n' -> skip ;
-MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> skip ;
-WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> skip ;
+// Sent to the HIDDEN channel instead of being skipped: the parser still ignores these tokens, but they
+// remain in the token stream, so TokenStreamRewriter.getText() (used for inlining in DMLParserWrapper)
+// reproduces the whitespace and comments of the original script.
+LINE_COMMENT : '#' .*? '\r'? '\n' -> channel(HIDDEN) ;
+MULTILINE_BLOCK_COMMENT : '/*' .*? '*/' -> channel(HIDDEN) ;
+WHITESPACE : (' ' | '\t' | '\r' | '\n')+ -> channel(HIDDEN) ;
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
index 56eb8ca..49e96e7 100644
--- a/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
+++ b/src/main/java/org/apache/sysml/parser/dml/DmlPreprocessor.java
@@ -20,7 +20,6 @@
package org.apache.sysml.parser.dml;
import java.util.HashSet;
-import java.util.Set;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.Token;
@@ -84,16 +83,10 @@ import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext;
public class DmlPreprocessor implements DmlListener {
protected final CustomErrorListener errorListener;
- // Names of user internal and external functions definitions
- protected Set<String> functions;
public DmlPreprocessor(CustomErrorListener errorListener) {
this.errorListener = errorListener;
- functions = new HashSet<>();
- }
-
- public Set<String> getFunctionDefs() {
- return functions;
+ // User-defined function names now live on the shared CustomErrorListener (readable via its
+ // getFunctionDefs()); every new preprocessor resets them to an empty set.
+ this.errorListener.functions = new HashSet<>();
}
@Override
@@ -113,8 +106,8 @@ public class DmlPreprocessor implements DmlListener {
public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {}
protected void validateFunctionName(String name, ParserRuleContext ctx) {
- if (!functions.contains(name)) {
- functions.add(name);
+ if (!errorListener.functions.contains(name)) {
+ errorListener.functions.add(name);
}
else {
notifyErrorListeners("Function Name Conflict: '" + name + "' already defined in " + errorListener.getCurrentFileName(), ctx.start);
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
new file mode 100644
index 0000000..34d886c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/parser/dml/InlineHelper.java
@@ -0,0 +1,798 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.antlr.v4.runtime.Token;
+import org.antlr.v4.runtime.TokenStreamRewriter;
+import org.antlr.v4.runtime.tree.ErrorNode;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.sysml.parser.Expression;
+import org.apache.sysml.parser.ParameterExpression;
+import org.apache.sysml.parser.common.CommonSyntacticValidator;
+import org.apache.sysml.parser.common.CustomErrorListener;
+import org.apache.sysml.parser.dml.DmlParser.AccumulatorAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.AddSubExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.AssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.AtomicExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanAndExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanNotExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BooleanOrExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.BuiltinFunctionExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.CommandlineParamExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.CommandlinePositionExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstDoubleIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstFalseExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstIntIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstStringIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ConstTrueExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.DataIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ExternalFunctionDefExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ForStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.FunctionCallAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.FunctionCallMultiAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IfStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IfdefAssignmentStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.ImportStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.IndexedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.InternalFunctionDefExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.IterablePredicateColonExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.IterablePredicateSeqExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MatrixDataTypeCheckContext;
+import org.apache.sysml.parser.dml.DmlParser.MatrixMulExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.Ml_typeContext;
+import org.apache.sysml.parser.dml.DmlParser.ModIntDivExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MultDivExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.MultiIdExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ParForStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.ParameterizedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.PathStatementContext;
+import org.apache.sysml.parser.dml.DmlParser.PowerExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ProgramrootContext;
+import org.apache.sysml.parser.dml.DmlParser.RelationalExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.SimpleDataIdentifierExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.StatementContext;
+import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.StrictParameterizedKeyValueStringContext;
+import org.apache.sysml.parser.dml.DmlParser.TypedArgAssignContext;
+import org.apache.sysml.parser.dml.DmlParser.TypedArgNoAssignContext;
+import org.apache.sysml.parser.dml.DmlParser.UnaryExpressionContext;
+import org.apache.sysml.parser.dml.DmlParser.ValueTypeContext;
+import org.apache.sysml.parser.dml.DmlParser.WhileStatementContext;
+
+/**
+ * This class is used to generate inline-able methods.
+ * It does so in two phases:
+ * - Phase 1. Rewriting phase where local variables are rewritten by adding a prefix.
+ * - Phase 2. Capture the body of the functions using InlineableMethods class
+ */
+public class InlineHelper extends CommonSyntacticValidator implements DmlListener {
+ // Random prefix prepended to the local variables and arguments of each function during the rewrite phase.
+ final static String ARG_PREFIX;
+ static {
+ 	Random rand = new Random();
+ 	// Math.abs(Long.MIN_VALUE) is still negative and would put a '-' into the generated
+ 	// identifier; masking off the sign bit always yields a non-negative number.
+ 	ARG_PREFIX = "INTERNAL_PREFIX_" + (rand.nextLong() & Long.MAX_VALUE) + "_" + (rand.nextLong() & Long.MAX_VALUE) + "_";
+ }
+ // Function name -> captured body; filled while walking in the capture phase (phase 2)
+ public HashMap<String, InlineableMethods> inlineMap = new HashMap<String, InlineableMethods>();
+ // Rewriter of the phase-1 token stream; only written to during the rewrite phase
+ TokenStreamRewriter rewriter;
+
+ // Walk state, maintained internally:
+ // variables seen in the function currently being walked
+ HashSet<String> variables = new HashSet<String>();
+ // name of the enclosing function, or null outside any function definition
+ String currentFunction = null;
+ // true while rewriting local variables (phase 1), false while capturing bodies (phase 2)
+ boolean isRewritePhase;
+
+ /**
+  * @param errorListener listener that captures potential syntactic errors
+  * @param argVals command-line arguments
+  * @param sourceNamespace source namespace
+  * @param prepFunctions names of the user-defined functions collected by the preprocessor
+  * @param rewriter1 token rewriter used to prefix local variables during the rewrite phase
+  */
+ public InlineHelper(CustomErrorListener errorListener, Map<String, String> argVals, String sourceNamespace,
+ 		Set<String> prepFunctions, TokenStreamRewriter rewriter1) {
+ 	super(errorListener, argVals, sourceNamespace, prepFunctions);
+ 	this.rewriter = rewriter1;
+ }
+
+ /**
+  * Selects what the next tree walk does.
+  *
+  * @param isRewritePhase1 true to rewrite local variables (phase 1), false to capture function bodies (phase 2)
+  */
+ void setPhase(boolean isRewritePhase1) {
+ 	this.isRewritePhase = isRewritePhase1;
+ }
+
+
+ @Override
+ public void enterInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
+ 	// Entering a function body: remember its name and start a fresh variable set
+ 	this.currentFunction = ctx.name.getText();
+ 	this.variables.clear();
+ }
+
+ @Override
+ public void exitInternalFunctionDefExpression(InternalFunctionDefExpressionContext ctx) {
+ 	if(!isRewritePhase) {
+ 		// Capture the (already rewritten) body, one statement per line
+ 		StringBuilder sb = new StringBuilder();
+ 		for(StatementContext stmt : ctx.body) {
+ 			sb.append(statementSourceText(stmt));
+ 			sb.append("\n");
+ 		}
+ 		ArrayList<String> inputArgs = new ArrayList<>();
+ 		for(TypedArgAssignContext in : ctx.inputParams) {
+ 			inputArgs.add(ARG_PREFIX + in.paramName.getText());
+ 		}
+ 		ArrayList<String> retVariables = new ArrayList<>();
+ 		for(TypedArgNoAssignContext out : ctx.outputParams) {
+ 			retVariables.add(ARG_PREFIX + out.paramName.getText());
+ 		}
+ 		// Hand over a copy: 'variables' is cleared right below and reused for the next function
+ 		inlineMap.put(currentFunction, new InlineableMethods(currentFunction, sb.toString(),
+ 			new HashSet<String>(variables), inputArgs, retVariables));
+ 	}
+ 	currentFunction = null;
+ 	variables.clear();
+ }
+
+ /**
+  * Returns the source text of a statement exactly as it appears in the parsed input, including the
+  * whitespace and comments between its tokens. ParserRuleContext.getText() omits hidden-channel tokens
+  * and would glue neighbouring tokens together (e.g. "else if" would become "elseif").
+  */
+ private static String statementSourceText(ParserRuleContext stmt) {
+ 	Token first = stmt.start;
+ 	Token last = stmt.stop;
+ 	if(first == null || last == null || first.getInputStream() == null
+ 		|| first.getStartIndex() < 0 || last.getStopIndex() < first.getStartIndex()) {
+ 		// Incomplete or empty context: fall back to the token-only text
+ 		return stmt.getText();
+ 	}
+ 	return first.getInputStream().getText(
+ 		org.antlr.v4.runtime.misc.Interval.of(first.getStartIndex(), last.getStopIndex()));
+ }
+
+ @Override
+ public void enterIndexedExpression(IndexedExpressionContext ctx) {
+ 	// Only rewrite inside a function body, and only during phase 1
+ 	if(currentFunction == null || !isRewritePhase)
+ 		return;
+ 	rewriter.insertBefore(ctx.name, " " + ARG_PREFIX);
+ }
+
+ @Override
+ public void exitIndexedExpression(IndexedExpressionContext ctx) {
+ if(currentFunction != null)
+ variables.add(ctx.name.getText());
+ }
+
+ @Override
+ public void enterSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {
+ 	if(currentFunction == null || !isRewritePhase)
+ 		return;
+ 	// Prefix the identifier and surround it with spaces
+ 	rewriter.insertBefore(ctx.start, " " + ARG_PREFIX);
+ 	rewriter.insertAfter(ctx.stop, " ");
+ }
+
+ @Override
+ public void exitSimpleDataIdentifierExpression(SimpleDataIdentifierExpressionContext ctx) {
+ 	// Record the identifier for the enclosing function (both phases)
+ 	if(currentFunction == null)
+ 		return;
+ 	variables.add(ctx.getText());
+ }
+
+ @Override
+ public void enterForStatement(ForStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX);
+ rewriter.insertAfter(ctx.iterVar, " ");
+ }
+ }
+
+ @Override
+ public void enterParForStatement(ParForStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ rewriter.insertBefore(ctx.iterVar, " " + ARG_PREFIX);
+ rewriter.insertAfter(ctx.iterVar, " ");
+ }
+ }
+
+ @Override
+ public void exitForStatement(ForStatementContext ctx) {
+ if(currentFunction != null)
+ variables.add(ctx.iterVar.getText());
+ if(currentFunction != null && isRewritePhase) {
+ if(ctx.body != null && ctx.body.size() > 0)
+ rewriter.insertBefore(ctx.body.get(0).start, "\n");
+ rewriter.insertAfter(ctx.stop, "\n");
+ }
+ }
+
+ @Override
+ public void exitParForStatement(ParForStatementContext ctx) {
+ if(currentFunction != null)
+ variables.add(ctx.iterVar.getText());
+ if(currentFunction != null && isRewritePhase) {
+ if(ctx.body != null && ctx.body.size() > 0)
+ rewriter.insertBefore(ctx.body.get(0).start, "\n");
+ rewriter.insertAfter(ctx.stop, "\n");
+ }
+ }
+
+
+ /**
+  * InlineHelper performs no DML-syntax conversion; this implementation always returns null.
+  */
+ @Override
+ protected ConvertedDMLSyntax convertToDMLSyntax(ParserRuleContext ctx, String namespace, String functionName,
+ 		ArrayList<ParameterExpression> paramExpression, Token fnName) {
+ 	// Intentionally no conversion
+ 	return null;
+ }
+
+ @Override
+ public void enterAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterAddSubExpression(AddSubExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterAssignmentStatement(AssignmentStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterAtomicExpression(AtomicExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterBooleanAndExpression(BooleanAndExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterBooleanNotExpression(BooleanNotExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterBooleanOrExpression(BooleanOrExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterCommandlineParamExpression(CommandlineParamExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterConstFalseExpression(ConstFalseExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterConstIntIdExpression(ConstIntIdExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterConstStringIdExpression(ConstStringIdExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterConstTrueExpression(ConstTrueExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterDataIdExpression(DataIdExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterEveryRule(ParserRuleContext arg0) {
+
+
+ }
+
+ @Override
+ public void enterExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {
+
+
+ }
+
+
+ @Override
+ public void enterFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterIfStatement(IfStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterImportStatement(ImportStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterMatrixMulExpression(MatrixMulExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterMl_type(Ml_typeContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterModIntDivExpression(ModIntDivExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterMultDivExpression(MultDivExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterMultiIdExpression(MultiIdExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterParameterizedExpression(ParameterizedExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterPathStatement(PathStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterPowerExpression(PowerExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterProgramroot(ProgramrootContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterRelationalExpression(RelationalExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterTypedArgAssign(TypedArgAssignContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterTypedArgNoAssign(TypedArgNoAssignContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterUnaryExpression(UnaryExpressionContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterValueType(ValueTypeContext ctx) {
+
+
+ }
+
+ @Override
+ public void enterWhileStatement(WhileStatementContext ctx) {
+
+
+ }
+
+ @Override
+ public void exitAccumulatorAssignmentStatement(AccumulatorAssignmentStatementContext ctx) {
+ 	// Phase 1, inside a function: append ";" and a newline after the statement
+ 	if(currentFunction == null || !isRewritePhase)
+ 		return;
+ 	rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ @Override
+ public void exitAddSubExpression(AddSubExpressionContext ctx) {}
+
+ @Override
+ public void exitAssignmentStatement(AssignmentStatementContext ctx) {
+ 	// Phase 1, inside a function: append ";" and a newline after the statement
+ 	if(currentFunction == null || !isRewritePhase)
+ 		return;
+ 	rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ // exit* callbacks that need no action for inlining.
+ @Override
+ public void exitAtomicExpression(AtomicExpressionContext ctx) {}
+
+ @Override
+ public void exitBooleanAndExpression(BooleanAndExpressionContext ctx) {}
+
+ @Override
+ public void exitBooleanNotExpression(BooleanNotExpressionContext ctx) {}
+
+ @Override
+ public void exitBooleanOrExpression(BooleanOrExpressionContext ctx) {}
+
+ @Override
+ public void exitBuiltinFunctionExpression(BuiltinFunctionExpressionContext ctx) {}
+
+ @Override
+ public void exitCommandlineParamExpression(CommandlineParamExpressionContext ctx) {}
+
+ @Override
+ public void exitCommandlinePositionExpression(CommandlinePositionExpressionContext ctx) {}
+
+ @Override
+ public void exitConstDoubleIdExpression(ConstDoubleIdExpressionContext ctx) {}
+
+ @Override
+ public void exitConstFalseExpression(ConstFalseExpressionContext ctx) {}
+
+ @Override
+ public void exitConstIntIdExpression(ConstIntIdExpressionContext ctx) {}
+
+ @Override
+ public void exitConstStringIdExpression(ConstStringIdExpressionContext ctx) {}
+
+ @Override
+ public void exitConstTrueExpression(ConstTrueExpressionContext ctx) {}
+
+ @Override
+ public void exitDataIdExpression(DataIdExpressionContext ctx) {}
+
+ @Override
+ public void exitEveryRule(ParserRuleContext arg0) {}
+
+ @Override
+ public void exitExternalFunctionDefExpression(ExternalFunctionDefExpressionContext ctx) {}
+
+
+
+ @Override
+ public void exitFunctionCallAssignmentStatement(FunctionCallAssignmentStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ }
+
+ @Override
+ public void exitFunctionCallMultiAssignmentStatement(FunctionCallMultiAssignmentStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ }
+
+ @Override
+ public void exitIfdefAssignmentStatement(IfdefAssignmentStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ }
+
+ @Override
+ public void exitIfStatement(IfStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ if(ctx.ifBody != null && ctx.ifBody.size() > 0)
+ rewriter.insertBefore(ctx.ifBody.get(0).start, "\n");
+ if(ctx.elseBody != null && ctx.elseBody.size() > 0)
+ rewriter.insertBefore(ctx.elseBody.get(0).start, "\n");
+ rewriter.insertAfter(ctx.stop, "\n");
+ }
+
+ }
+
+ @Override
+ public void exitImportStatement(ImportStatementContext ctx) {
+     // Acts only while a function is being processed and only in the rewrite phase.
+     if (currentFunction == null || !isRewritePhase)
+         return;
+     // Terminate the statement with ';' and a newline in the rewritten text.
+     rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ @Override
+ public void exitIterablePredicateColonExpression(IterablePredicateColonExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitIterablePredicateSeqExpression(IterablePredicateSeqExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitMatrixDataTypeCheck(MatrixDataTypeCheckContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitMatrixMulExpression(MatrixMulExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitMl_type(Ml_typeContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitModIntDivExpression(ModIntDivExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitMultDivExpression(MultDivExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitMultiIdExpression(MultiIdExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitParameterizedExpression(ParameterizedExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitPathStatement(PathStatementContext ctx) {
+     // Acts only while a function is being processed and only in the rewrite phase.
+     if (currentFunction == null || !isRewritePhase)
+         return;
+     // Terminate the statement with ';' and a newline in the rewritten text.
+     rewriter.insertAfter(ctx.stop, ";\n");
+ }
+
+ @Override
+ public void exitPowerExpression(PowerExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitProgramroot(ProgramrootContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitRelationalExpression(RelationalExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitStrictParameterizedExpression(StrictParameterizedExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitStrictParameterizedKeyValueString(StrictParameterizedKeyValueStringContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitTypedArgAssign(TypedArgAssignContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitTypedArgNoAssign(TypedArgNoAssignContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitUnaryExpression(UnaryExpressionContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitValueType(ValueTypeContext ctx) {
+     // No-op on exit from this rule.
+ }
+
+ @Override
+ public void exitWhileStatement(WhileStatementContext ctx) {
+ if(currentFunction != null && isRewritePhase) {
+ if(ctx.body != null && ctx.body.size() > 0)
+ rewriter.insertBefore(ctx.body.get(0).start, "\n");
+ rewriter.insertAfter(ctx.stop, "\n");
+ }
+ }
+
+ @Override
+ public String falseStringLiteral() {
+     // This listener defines no literal for false.
+     return null;
+ }
+
+ @Override
+ protected Expression handleLanguageSpecificFunction(ParserRuleContext ctx, String functionName,
+         ArrayList<ParameterExpression> paramExpressions) {
+     // No language-specific functions are recognized here; the result is always null.
+     return null;
+ }
+
+ @Override
+ public String namespaceResolutionOp() {
+     // This listener defines no namespace-resolution operator.
+     return null;
+ }
+
+ @Override
+ public String trueStringLiteral() {
+     // This listener defines no literal for true.
+     return null;
+ }
+
+ @Override
+ public void visitErrorNode(ErrorNode arg0) {
+
+
+ }
+
+ @Override
+ public void visitTerminal(TerminalNode arg0) {
+     // No-op: terminal nodes are ignored by this listener.
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
new file mode 100644
index 0000000..d3b0d11
--- /dev/null
+++ b/src/main/java/org/apache/sysml/parser/dml/InlineableMethods.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.parser.dml;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+
+/**
+ * Holds a single DML function that can be inlined at a call site: the function name,
+ * its body text, the variables referenced in the body, and the names of its formal
+ * input arguments and return variables. Call
+ * {@link #getInlinedDML(ArrayList, ArrayList)} to obtain the body rewritten for one
+ * particular call.
+ */
+public class InlineableMethods {
+    // Variables referenced in the body, sorted longest-first. Presumably this ordering keeps a
+    // shorter name from being substituted inside a longer name that contains it -- confirm.
+    ArrayList<String> _variables;
+    final String _body;
+    final String _fnName;
+    final ArrayList<String> _inputArgs;
+    final ArrayList<String> _retVariables;
+    // Incremented on every inlining so each call site gets distinct names for internal variables.
+    // NOTE(review): not synchronized; assumes DML generation happens on a single thread.
+    static int CALLER_ID = 1;
+
+    /**
+     * @param fnName       name of the function
+     * @param body         body text of the function
+     * @param variables    variables referenced in the body
+     * @param inputArgs    formal input argument names, in the positional order used to match actual arguments
+     * @param retVariables formal return variable names, in the positional order used to match actual return variables
+     */
+    public InlineableMethods(String fnName, String body, HashSet<String> variables, ArrayList<String> inputArgs, ArrayList<String> retVariables) {
+        _fnName = fnName;
+        _body = body;
+        _variables = new ArrayList<String>(variables);
+        _variables.sort(Comparator.comparing(String::length).reversed());
+        _inputArgs = inputArgs;
+        _retVariables = retVariables;
+    }
+
+    /** Returns the variables referenced in the body (longest name first). */
+    public ArrayList<String> getLocalVariables() {
+        return _variables;
+    }
+
+    /**
+     * Substitutes every variable of the body: formal arguments found in {@code actualArguments}
+     * are replaced by the caller's names, all other variables are renamed to names unique to
+     * this call site.
+     */
+    private String _getInlinedDML(HashMap<String, String> actualArguments) {
+        String ret = _body;
+        int callerID = CALLER_ID++;
+        for(String var : _variables) {
+            String replacement;
+            if(actualArguments.containsKey(var)) {
+                // formal input argument or return variable
+                replacement = actualArguments.get(var);
+            }
+            else {
+                // internal variable: strip the helper prefix and make the name call-site specific
+                String originalVarName = var.substring(InlineHelper.ARG_PREFIX.length());
+                replacement = LOCAL_ARG_PREFIX + _fnName + "_" + callerID + "_" + originalVarName;
+            }
+            // Literal replacement: String.replaceAll would interpret the variable name as a regex
+            // (e.g. '.') and '$' or '\' in the replacement as group references/escapes.
+            ret = ret.replace(var, replacement);
+        }
+        return ret;
+    }
+
+    /**
+     * Returns the function body with formal arguments replaced by the given actual arguments
+     * and all other variables renamed to call-site-unique names.
+     *
+     * @param actualInputArgs    actual input arguments, matched positionally to the formal inputs
+     * @param actualRetVariables variables receiving the results, matched positionally to the formal returns
+     * @return the inlined DML text
+     * @throws DMLRuntimeException if the number of inputs or return variables does not match the function
+     */
+    public String getInlinedDML(ArrayList<String> actualInputArgs, ArrayList<String> actualRetVariables) {
+        HashMap<String, String> actualArguments = new HashMap<>();
+        if(actualInputArgs.size() != _inputArgs.size()) {
+            throw new DMLRuntimeException("Incorrect number of input arguments for the function " + _fnName + ": expected "
+                + _inputArgs.size() + " (" + String.join(", ", _inputArgs) + ") but found " + actualInputArgs.size()
+                + " (" + String.join(", ", actualInputArgs) + ")");
+        }
+        if(actualRetVariables.size() != _retVariables.size()) {
+            throw new DMLRuntimeException("Incorrect number of return variables for the function " + _fnName + ": expected "
+                + _retVariables.size() + " (" + String.join(", ", _retVariables) + ") but found " + actualRetVariables.size()
+                + " (" + String.join(", ", actualRetVariables) + ")");
+        }
+        for(int i = 0; i < _inputArgs.size(); i++) {
+            actualArguments.put(_inputArgs.get(i), actualInputArgs.get(i));
+        }
+        for(int i = 0; i < _retVariables.size(); i++) {
+            actualArguments.put(_retVariables.get(i), actualRetVariables.get(i));
+        }
+        return _getInlinedDML(actualArguments);
+    }
+
+    // Randomized prefix for renamed internal variables, making clashes with user variables unlikely.
+    static final String LOCAL_ARG_PREFIX;
+    static {
+        Random rand = new Random();
+        // Mask the sign bit instead of Math.abs: Math.abs(Long.MIN_VALUE) is negative and would
+        // put a '-' into the generated identifier.
+        LOCAL_ARG_PREFIX = "LOCAL_" + (rand.nextLong() & Long.MAX_VALUE) + "_" + (rand.nextLong() & Long.MAX_VALUE);
+    }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index 03bc3b3..15dd23e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -307,7 +307,14 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
}
public MatrixCharacteristics getMatrixCharacteristics() {
- return _metaData.getMatrixCharacteristics();
+ MatrixCharacteristics mc = _metaData.getMatrixCharacteristics();
+ if(mc.getRowsPerBlock() == -1) {
+ mc.setRowsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+ }
+ if(mc.getColsPerBlock() == -1) {
+ mc.setColsPerBlock(OptimizerUtils.DEFAULT_BLOCKSIZE);
+ }
+ return mc;
}
public abstract void refreshMetaData();
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
index 6772b4a..a08d4fd 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -298,7 +298,7 @@ public class GPUMemoryManager {
evictOrClear(sizeBasedUnlockedGPUObjects.get(), opcode);
A = cudaMallocNoWarn(tmpA, size, null);
if(A == null)
- LOG.warn("cudaMalloc failed after clearing/evicting based on size.");
+ LOG.debug("cudaMalloc failed after clearing/evicting based on size.");
if(ConfigurationManager.isStatistics()) {
long totalTime = System.nanoTime() - t0;
GPUStatistics.cudaEvictTime.add(totalTime);
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index fbcd3e2..8a100b4 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -924,7 +924,7 @@ class Caffe2DML(BaseSystemMLClassifier):
self.estimator.setWeightsToIgnore(ignore_weights)
def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None,
- output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None):
+ output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None, inline_nn_library=None):
"""
Set input to Caffe2DML
@@ -937,9 +937,12 @@ class Caffe2DML(BaseSystemMLClassifier):
output_activations: (developer flag) directory to output activations of each layer as csv while prediction. To be used only in batch mode (default: None)
perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: False)
parfor_parameters: dictionary for parfor parameters when using allreduce-style algorithms (default: "")
+ inline_nn_library: whether to inline the NN library when generating DML using Caffe2DML (default: False)
"""
if debug is not None:
self.estimator.setInput("$debug", str(debug).upper())
+ if inline_nn_library is not None:
+ self.estimator.setInput("$inline_nn_library", str(inline_nn_library).upper())
if train_algo is not None:
self.estimator.setInput("$train_algo", str(train_algo).lower())
if test_algo is not None:
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 26e554f..8ddb1fe 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -119,8 +119,9 @@ object Caffe2DML {
val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
// ------------------------------------------------------------------------
val USE_PLUS_EQ = true
- def layerDir = "nn/layers/"
- def optimDir = "nn/optim/"
+ // Root of the NN library scripts; layer and optimizer scripts live in sub-directories of it.
+ def nnDir: String = "nn/"
+ def layerDir: String = s"${nnDir}layers/"
+ def optimDir: String = s"${nnDir}optim/"
// Naming conventions:
val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
@@ -159,6 +160,7 @@ object Caffe2DML {
val BATCH_ALGORITHM = "batch"
val ALLREDUCE_ALGORITHM = "allreduce"
val ALLREDUCE_PARALLEL_BATCHES_ALGORITHM = "allreduce_parallel_batches"
+ // When true, Utils.invoke inlines NN-library function bodies instead of emitting calls.
+ // Assigned from the "$inline_nn_library" input by Caffe2DML and Caffe2DMLModel.
+ var INLINE_NN_LIBRARY: Boolean = false
}
class Caffe2DML(val sc: SparkContext,
@@ -312,6 +314,7 @@ class Caffe2DML(val sc: SparkContext,
// Flags passed by user
val DEBUG_TRAINING = if (inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+ Caffe2DML.INLINE_NN_LIBRARY = if (inputs.containsKey("$inline_nn_library")) inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
setDebugFlags(DEBUG_TRAINING)
@@ -721,6 +724,7 @@ class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: C
reset // Reset the state of DML generator for training script.
val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+ Caffe2DML.INLINE_NN_LIBRARY = if (estimator.inputs.containsKey("$inline_nn_library")) estimator.inputs.get("$inline_nn_library").toLowerCase.toBoolean else false
assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
estimator.setDebugFlags(DEBUG_PREDICTION)
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index d664f6e..b290983 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -29,6 +29,7 @@ import caffe.Caffe.EltwiseParameter.EltwiseOp
import org.apache.sysml.runtime.DMLRuntimeException;
import java.util.ArrayList
import caffe.Caffe.PoolingParameter.PoolMethod
+import scala.collection.JavaConverters._
trait CaffeLayer extends BaseDMLGenerator {
// -------------------------------------------------
@@ -125,7 +126,7 @@ trait CaffeLayer extends BaseDMLGenerator {
// The layers that have a corresponding dml script call this method.
// Assumption: the first variable of resultVariables is always dX
def invokeBackward(dmlScript: StringBuilder, outSuffix: String, resultVariables: List[String], arguments: String*): Unit = {
- invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
+ Utils.invoke(Caffe2DML.layerDir, dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
dmlScript.append("; ")
bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
@@ -140,6 +141,13 @@ trait CaffeLayer extends BaseDMLGenerator {
dmlScript.append("\n")
}
// --------------------------------------------------------------------------------------
+
+ // Invokes (or inlines) a layer-library function with an explicit argument list, ending with a newline.
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit = {
+   val libraryDir = Caffe2DML.layerDir
+   Utils.invoke(libraryDir, dmlScript, namespace1, returnVariables, functionName, arguments, appendNewLine = true)
+ }
+ // Invokes (or inlines) a layer-library function with varargs arguments; the caller chooses whether a newline follows.
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit = {
+   val argumentList = arguments.toList
+   Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, argumentList, appendNewLine)
+ }
+ // Invokes (or inlines) a layer-library function with varargs arguments, ending with a newline.
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit = {
+   val argumentList = arguments.toList
+   Utils.invoke(Caffe2DML.layerDir, dmlScript, namespace1, returnVariables, functionName, argumentList, appendNewLine = true)
+ }
}
trait IsLossLayer extends CaffeLayer {
@@ -1603,4 +1611,4 @@ class DeConvolution(val param: LayerParameter, val id: Int, val net: CaffeNetwor
if (convParam.hasPadW) convParam.getPadW.toString
else if (convParam.getPadCount > 0) convParam.getPad(0).toString
else "0"
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 8559c60..da963b9 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -67,10 +67,12 @@ trait CaffeSolver {
val hasDecayMult = layer.param.getParamList != null && layer.param.getParamList.size >= 1 && layer.param.getParamList.get(0).hasDecayMult
val newLambda = if(hasDecayMult) layer.param.getParamList.get(0).getDecayMult * lambda else lambda
- dmlScript.append("\t").append(layer.dWeight + "_reg = " + regularizationSource + "::backward(" + layer.weight + ", " + newLambda + ")\n")
+ Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", List[String](layer.dWeight + "_reg"), "backward",
+ List[String](layer.weight, "" + newLambda), true)
dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " + " + layer.dWeight + "_reg\n")
if(layer.shouldUpdateExtraWeight) {
- dmlScript.append("\t").append(layer.dExtraWeight + "_reg = " + regularizationSource + "::backward(" + layer.extraWeight + ", " + newLambda + ")\n")
+ Utils.invoke(Caffe2DML.layerDir, dmlScript, regularizationSource + "::", List[String](layer.dExtraWeight + "_reg"), "backward",
+ List[String](layer.extraWeight, "" + newLambda), true)
dmlScript.append("\t").append(layer.dExtraWeight + " = " + layer.dExtraWeight + " + " + layer.dExtraWeight + "_reg\n")
}
}
@@ -339,32 +341,19 @@ class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, momentu
* input v.
*/
def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
- val fn = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
- val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
if (!Caffe2DML.USE_NESTEROV_UDF) {
regularization_update(regularizationType, lambda, dmlScript, layer)
}
if (layer.shouldUpdateWeight)
- dmlScript
- .append("\t")
- .append(
- "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
- "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + ")\n"
- )
+   nesterovUpdate(dmlScript, layer.weight, layer.dWeight, getWeightLr(layer))
if (layer.shouldUpdateExtraWeight)
- dmlScript
- .append("\t")
- .append(
- "[" + commaSep(layer.extraWeight, layer.extraWeight + "_v") + "] " +
- "= " + fn + "(" + commaSep(layer.extraWeight, layer.dExtraWeight, getWeightLr(layer), momentum.toString, layer.extraWeight + "_v") + lastParameter + ")\n"
- )
+   nesterovUpdate(dmlScript, layer.extraWeight, layer.dExtraWeight, getWeightLr(layer))
if (layer.shouldUpdateBias)
- dmlScript
- .append("\t")
- .append(
- "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
- "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
- )
+   nesterovUpdate(dmlScript, layer.bias, layer.dBias, getBiasLr(layer))
}
+
+ /**
+  * Emits the Nesterov momentum update of `weight` and its velocity `weight + "_v"`.
+  * With Caffe2DML.USE_NESTEROV_UDF the fused update_nesterov function is called and is also
+  * given lambda, because update() skips regularization_update in that case; without the UDF,
+  * sgd_nesterov::update is emitted through Utils.invoke (which honours INLINE_NN_LIBRARY).
+  */
+ private def nesterovUpdate(dmlScript: StringBuilder, weight: String, dWeight: String, lr: String): Unit = {
+   val velocity = weight + "_v"
+   if (Caffe2DML.USE_NESTEROV_UDF)
+     dmlScript.append("\t").append(
+       "[" + commaSep(weight, velocity) + "] = update_nesterov(" +
+         commaSep(weight, dWeight, lr, momentum.toString, velocity) + ", " + lambda + ")\n")
+   else
+     Utils.invoke(Caffe2DML.optimDir, dmlScript, "sgd_nesterov::", List[String](weight, velocity), "update",
+       List[String](weight, dWeight, lr, momentum.toString, velocity), true)
+ }
def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_nesterov::init(" + layer.weight + ")\n")
@@ -372,4 +361,4 @@ class Nesterov(regularizationType:String = "L2", lambda: Double = 5e-04, momentu
if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_nesterov::init(" + layer.bias + ")\n")
}
// Name of the Nesterov optimizer script / namespace.
def sourceFileName: String = {
  "sgd_nesterov"
}
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 60396f1..59c75ad 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -75,39 +75,6 @@ trait BaseDMLGenerator {
sum(dmlScript, rhsVars)
dmlScript.append("\n")
}
- def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
- invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
- def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
- if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
- if (returnVariables.length > 1) dmlScript.append("[")
- dmlScript.append(returnVariables(0))
- if (returnVariables.length > 1) {
- for (i <- 1 until returnVariables.length) {
- dmlScript.append(",").append(returnVariables(i))
- }
- dmlScript.append("]")
- }
- dmlScript.append(" = ")
- dmlScript.append(namespace1)
- dmlScript.append(functionName)
- dmlScript.append("(")
- if (arguments != null) {
- if (arguments.length != 0)
- dmlScript.append(arguments(0))
- if (arguments.length > 1) {
- for (i <- 1 until arguments.length) {
- dmlScript.append(",").append(arguments(i))
- }
- }
- }
- dmlScript.append(")")
- if (appendNewLine)
- dmlScript.append("\n")
- }
- def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
- invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
- def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
- invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
def rightIndexing(dmlScript: StringBuilder, lhsVar:String, rhsVar: String, rl: String, ru: String, cl: String=null, cu: String=null): StringBuilder = {
dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("[")
if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
@@ -279,6 +246,7 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
// Append source statements for layers as well as solver
source(net, solver, if (isTraining) Array[String]("l1_reg") else null)
source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
+ source(dmlScript, numTabs, "util", Caffe2DML.nnDir)
if (isTraining) {
// Append external built-in function headers:
@@ -346,4 +314,4 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator {
// Sets update_mean_var to `value` on every BatchNorm layer of `net`.
// foreach rather than map: the loop is run only for its side effect, so no result collection is needed.
def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit =
  net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).foreach(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/systemml/blob/ef1945d7/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 63aaf91..da53f65 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -39,6 +39,11 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock
import org.apache.sysml.api.mlcontext.MLContext
import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaSparkContext
+import org.apache.sysml.parser.ParserWrapper
+import org.apache.sysml.parser.dml.DMLParserWrapper
+import org.apache.sysml.parser.dml.InlineableMethods
+import java.util.ArrayList
+import scala.collection.JavaConverters._
object Utils {
// ---------------------------------------------------------------------------------------------
@@ -64,6 +69,57 @@ object Utils {
line = bufReader.readLine()
}
}
+
+ // Reads the DML script at `fileName` via ParserWrapper.readDMLScript, logging through Caffe2DML.LOG.
+ def readDMLScript(fileName: String): String = {
+   val log = Caffe2DML.LOG
+   ParserWrapper.readDMLScript(fileName, log)
+ }
+ // Cache of parsed library functions: namespace -> (function name -> inlineable function).
+ val inlineableMethods = new java.util.HashMap[String, java.util.HashMap[String, InlineableMethods]]()
+
+ /**
+  * Returns function `fnName` of `namespace` prepared for inlining. The source file is parsed
+  * with DMLParserWrapper.getInlineableMethods the first time the namespace is requested and
+  * the result is cached. Returns null if the namespace does not define `fnName`.
+  */
+ def getInlineableMethod(sourceFilePath:String, namespace:String, fnName:String):InlineableMethods = {
+   // containsKey is the java.util.HashMap API; `contains` only compiled via the deprecated
+   // implicit JavaConversions wrapper.
+   if (!inlineableMethods.containsKey(namespace))
+     inlineableMethods.put(namespace, new DMLParserWrapper().getInlineableMethods(sourceFilePath, null, namespace, null))
+   inlineableMethods.get(namespace).get(fnName)
+ }
+
+ /**
+  * Appends a call of `namespace1 + functionName` to `dmlScript`, e.g. "[a,b] = ns::fn(x,y)".
+  * When Caffe2DML.INLINE_NN_LIBRARY is set, the function body read from
+  * `dir + <namespace1 without "::"> + ".dml"` is inlined instead; functions called from that
+  * body are not inlined recursively.
+  *
+  * @param dir             directory of the library source file (e.g. Caffe2DML.layerDir)
+  * @param dmlScript       script being generated
+  * @param namespace1      namespace prefix of the function, normally ending in "::"
+  * @param returnVariables variables receiving the results; must not be empty
+  * @param functionName    name of the function within the namespace
+  * @param arguments       actual arguments; null is treated as no arguments
+  * @param appendNewLine   whether a non-inlined call is followed by a newline (inlined code always is)
+  * @throws DMLRuntimeException if no return variable is given or the function to inline is not found
+  */
+ def invoke(dir:String, dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
+   if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
+   val args = if (arguments == null) List[String]() else arguments
+   if (Caffe2DML.INLINE_NN_LIBRARY) {
+     val sourceFileName = if (namespace1.endsWith("::")) namespace1.substring(0, namespace1.length() - 2) else namespace1
+     val sourceFilePath = dir + sourceFileName + ".dml"
+     val method = getInlineableMethod(sourceFilePath, namespace1, functionName)
+     if (method == null)
+       throw new DMLRuntimeException("Cannot inline the function " + namespace1 + functionName + ": it was not found in " + sourceFilePath)
+     dmlScript.append(method.getInlinedDML(new ArrayList[String](args.asJava), new ArrayList[String](returnVariables.asJava)))
+     dmlScript.append("\n")
+   }
+   else {
+     // A single return variable is written bare, several are wrapped in [ ... ].
+     if (returnVariables.length > 1) dmlScript.append(returnVariables.mkString("[", ",", "]"))
+     else dmlScript.append(returnVariables(0))
+     dmlScript.append(" = ").append(namespace1).append(functionName)
+     dmlScript.append(args.mkString("(", ",", ")"))
+     if (appendNewLine)
+       dmlScript.append("\n")
+   }
+ }
// ---------------------------------------------------------------------------------------------
// Reads the solver file at `solverFilePath` and converts it into a CaffeSolver.
def parseSolver(solverFilePath: String): CaffeSolver = {
  val solverParam = readCaffeSolver(solverFilePath)
  parseSolver(solverParam)
}
@@ -324,4 +380,4 @@ class Utils {
// Instance-level entry point that forwards to the companion object's Utils.saveCaffeModelFile.
def saveCaffeModelFile(sc: JavaSparkContext, deployFilePath: String, caffeModelFilePath: String, outputDirectory: String, format: String): Unit = {
  Utils.saveCaffeModelFile(sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
}
-}
+}
\ No newline at end of file