Posted to commits@mxnet.apache.org by in...@apache.org on 2017/12/02 00:57:17 UTC

[incubator-mxnet] branch master updated: Caffe to MXNet code translator (#8782)

This is an automated email from the ASF dual-hosted git repository.

indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b3016e  Caffe to MXNet code translator (#8782)
3b3016e is described below

commit 3b3016e8439ca3ebaf8055a13f04117869e24eca
Author: Indhu Bharathi <in...@gmail.com>
AuthorDate: Fri Dec 1 16:57:12 2017 -0800

    Caffe to MXNet code translator (#8782)
    
    Caffe Translator is a migration tool that helps developers migrate their existing Caffe code to MXNet and continue further development using MXNet.
    
    Caffe Translator takes the training/validation prototxt and solver prototxt as input and produces MXNet Python code as output. The translated Python code uses MXNet Symbol and Module API to build the network, reads data from LMDB files, runs training and saves the trained model using the MXNet Module API.
    
    More info here: https://github.com/apache/incubator-mxnet/tree/master/tools/caffe_translator
    
    * Add Caffe Translator.
    Caffe Translator translates Caffe training/validation prototxt to MXNet Python code.
    
    * Minor bug fix.
    
    * [WIP] Convergence test for caffe translator
    
    * [WIP - Caffe Translator testing] Skeleton code for going through each test, translating, training, testing and generating a report. Details are yet to be filled in.
    
    * [WIP - Caffe Translator Test] Add code to create directory and invoke translator.
    
    * [WIP Caffe Converter tests] Train the converted network and collect logs.
    
    * [WIP - Caffe Translator test] Add code to search for a particular metric in log file and extract it.
    
    * - Bug fixes
    - Code refactoring
    
    * Add ‘num_examples’ parameter to generated CaffeDataIter.
    
    CaffeDataIter requires the user to provide the number of examples in the dataset using the num_examples parameter.
    However, this information is not available in the Caffe prototxt.
    
    This commit lets the user add this information to the Caffe prototxt as a comment. Caffe parses it as a comment and ignores it. The translator picks up the value and uses it when generating CaffeDataIter.
    
    Example usage:
      data_param {
        source: "data/mnist/mnist_train_lmdb"
        #caffe2mxnet num_examples: 60000
        batch_size: 64
        backend: LMDB
      }
    
    * Update README.md
    
    * Update README.md
    
    * Update README.md
    
    * Create faq.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Update README.md
    
    * Update README.md
    
    * Get test directory as input from command line.
    
    * Refactor code.
    
    * Don't print translated code to console.
    
    * Make sure "#caffe2mxnet num_examples: ..." doesn't go into the iterator's prototxt.
    
    * Add sample test_dir
    
    * Add data directory
    
    * Add test configuration.
    Use the name test.cfg instead of test_description.txt
    
    * Create README.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Update faq.md
    
    * Add script to convert .caffemodel to .params file.
    Only handles convolution and fully connected layers, but can easily be extended to handle other layers.
    
    * Add ability to initialize networks with pretrained weights.
    
    * Add ability to convert weights for Deconvolution, BatchNorm and Scale.
    
    * Code cleanup.
    
    * Code cleanup. Remove the Constants class.
    
    * Code cleanup.
    
    * Code cleanup
    
    * Cleanup. Replace "indhub/mxnet" with "apache/incubator-mxnet".
    
    * Fix lint errors.
    
    * Fix lint issues.
    
    * Add a few more solvers for LeNet
    
    * Fix a regression.
    
    * Add Apache license.
    
    * Remove automated tests.
    
    * Create fat jar.
    Set version to 0.9.0.
    
    * Add license header.
    
    * Remove the 'all' suffix from the fat jar name.
    
    * Update README.md
    
    * Update README.md
    
    * Update README.md
    
    * Update README.md
    
    * Update gradle wrapper to use shadow plugin.
    
    * Initial work for signing and uploading artifacts to maven.
    
    * Include jar signature in the artifacts that get uploaded.
    
    * Add gradle properties required for signing and uploading artifacts to maven.
    
    * Print helpful error message when translator is not able to find the prototxt.
    
    * Print helpful error message in console and add a helpful comment in the generated code when an unknown lr policy is encountered.
    
    * Add final modifier when appropriate.
    
    * Rename GenHelper to GenerationHelper.
    
    * Expand imports
    
    * Some code style changes.
    
    * Update README.md
    
    * Update README.md
    
    * Create build_from_source.md
    
    * Update README.md
    
    * Update faq.md
    
    * Update README.md
    
    * Update build_from_source.md
    
    * Update README.md
    
    * Update README.md
    
    * Update faq.md
    
    * - Add source and doc jar.
    - Add ability to upload to staging repo.
    - Make sure all .asc files get uploaded.
    
    * Update README.md
    
    * Update maven pom with fields required for uploading artifacts.
    
    * Bump version to 0.9.1
    
    * Update README.md
    
    Add the list of command line parameters accepted by the Caffe Translator.
    
    * Update README.md
    
    Update repo location
    
    * After successful translation, print a message indicating that the translation was successful.
    
    * Update README.md
    
    * Update README.md
    
    * Update faq.md
    
    * Update README.md
    
    * Update README.md
    
    Provide a sample translated output
    
    * Update README.md
    
    Add links to examples to show what the following items look like:
    1. Translated code.
    2. LMDB file.
    3. Model saved after training.
    
    * Rename #caffe2mxnet directive to #CaffeToMXNet to avoid any confusion with Caffe 2.
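
The bullets above about converting a `.caffemodel` into a `.params` file (convolution, fully connected, Deconvolution, BatchNorm and Scale) rely on MXNet's checkpoint naming: learned weights are stored under `arg:` keys and auxiliary states such as BatchNorm running statistics under `aux:` keys, which is what `add_arg_param` in `convert_caffe_model.py` does with its `'arg:%s'` prefix. The following is a minimal, hypothetical sketch of that mapping in plain Python, not the script's actual code. It assumes layers arrive as `(name, type, blobs)` tuples with blobs as flat lists, that MXNet parameters use the default `<layer>_weight`, `<layer>_bias`, `<layer>_moving_mean` and `<layer>_moving_var` names, and, as we understand Caffe's BatchNorm layer, that its third blob holds a scale factor by which the stored statistics must be divided. Scale layers and BatchNorm's `gamma`/`beta` are left out.

```python
# Hypothetical sketch (not the code in convert_caffe_model.py): build an
# MXNet-style parameter dictionary from simplified Caffe layer records.
# Each layer is a (name, type, blobs) tuple; blobs are flat lists of floats.

# Layer types whose blobs are (weight, optional bias).
WEIGHTED_TYPES = {"Convolution", "InnerProduct", "Deconvolution"}


def caffe_blobs_to_mxnet(layers):
    """Return a dict keyed like an MXNet checkpoint ('arg:' / 'aux:' prefixes)."""
    params = {}
    for name, layer_type, blobs in layers:
        if layer_type in WEIGHTED_TYPES:
            # blobs[0] holds the weights; blobs[1], when present, the bias.
            params["arg:%s_weight" % name] = blobs[0]
            if len(blobs) > 1:
                params["arg:%s_bias" % name] = blobs[1]
        elif layer_type == "BatchNorm":
            # Assumed Caffe layout: blobs[2][0] is a scale factor; dividing by
            # it recovers the statistics, and a factor of 0 yields zeros.
            factor = blobs[2][0]
            scale = 0.0 if factor == 0 else 1.0 / factor
            params["aux:%s_moving_mean" % name] = [m * scale for m in blobs[0]]
            params["aux:%s_moving_var" % name] = [v * scale for v in blobs[1]]
        # Any other layer type carries no learned blobs in this sketch.
    return params


if __name__ == "__main__":
    example_layers = [
        ("conv1", "Convolution", [[0.5, -0.5], [0.1]]),
        ("bn1", "BatchNorm", [[2.0, 4.0], [8.0, 6.0], [2.0]]),
        ("relu1", "ReLU", []),
    ]
    for key, value in sorted(caffe_blobs_to_mxnet(example_layers).items()):
        print(key, value)
```

A real converter would wrap each value with `mx.nd.array` and write the dictionary with `mx.nd.save`, producing a file in the same key format that `mx.model.load_checkpoint` splits into arg and aux params.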
---
 tools/caffe_translator/README.md                   |  79 ++++
 tools/caffe_translator/build.gradle                | 152 +++++++
 tools/caffe_translator/build_from_source.md        |  21 +
 tools/caffe_translator/faq.md                      |  17 +
 tools/caffe_translator/gradle.properties           |  12 +
 .../gradle/wrapper/gradle-wrapper.jar              | Bin 0 -> 54731 bytes
 .../gradle/wrapper/gradle-wrapper.properties       |   5 +
 tools/caffe_translator/gradlew                     | 172 ++++++++
 tools/caffe_translator/gradlew.bat                 | 101 +++++
 .../scripts/convert_caffe_model.py                 | 121 ++++++
 tools/caffe_translator/settings.gradle             |   2 +
 .../io/mxnet/caffetranslator/CaffePrototxt.g4      |  74 ++++
 .../main/java/io/mxnet/caffetranslator/Config.java |  55 +++
 .../java/io/mxnet/caffetranslator/Converter.java   | 460 +++++++++++++++++++++
 .../mxnet/caffetranslator/CreateModelListener.java | 144 +++++++
 .../io/mxnet/caffetranslator/GenerationHelper.java | 195 +++++++++
 .../io/mxnet/caffetranslator/GeneratorOutput.java  |  35 ++
 .../java/io/mxnet/caffetranslator/Launcher.java    | 178 ++++++++
 .../main/java/io/mxnet/caffetranslator/Layer.java  | 141 +++++++
 .../java/io/mxnet/caffetranslator/MLModel.java     | 105 +++++
 .../io/mxnet/caffetranslator/ParserHelper.java     |  36 ++
 .../main/java/io/mxnet/caffetranslator/Solver.java |  98 +++++
 .../io/mxnet/caffetranslator/SolverListener.java   |  58 +++
 .../io/mxnet/caffetranslator/SymbolGenerator.java  |  29 ++
 .../caffetranslator/SymbolGeneratorFactory.java    |  53 +++
 .../main/java/io/mxnet/caffetranslator/Utils.java  |  42 ++
 .../generators/AccuracyMetricsGenerator.java       |  83 ++++
 .../caffetranslator/generators/BaseGenerator.java  |  60 +++
 .../generators/BatchNormGenerator.java             |  65 +++
 .../generators/ConcatGenerator.java                |  49 +++
 .../generators/ConvolutionGenerator.java           | 101 +++++
 .../generators/DeconvolutionGenerator.java         | 103 +++++
 .../generators/DropoutGenerator.java               |  43 ++
 .../generators/EltwiseGenerator.java               |  69 ++++
 .../caffetranslator/generators/FCGenerator.java    |  82 ++++
 .../generators/FlattenGenerator.java               |  49 +++
 .../generators/PermuteGenerator.java               |  48 +++
 .../generators/PluginIntLayerGenerator.java        |  80 ++++
 .../generators/PluginLayerHelper.java              |  63 +++
 .../generators/PluginLossGenerator.java            |  69 ++++
 .../generators/PoolingGenerator.java               |  86 ++++
 .../caffetranslator/generators/PowerGenerator.java |  51 +++
 .../caffetranslator/generators/ReluGenerator.java  |  44 ++
 .../caffetranslator/generators/ScaleGenerator.java |  66 +++
 .../generators/SoftmaxOutputGenerator.java         |  43 ++
 .../mxnet/caffetranslator/misc/CollectStats.java   |  73 ++++
 .../mxnet/caffetranslator/misc/StatsListener.java  | 103 +++++
 .../src/main/resources/templates/accuracy.st       |   2 +
 .../src/main/resources/templates/activation.st     |   1 +
 .../src/main/resources/templates/add.st            |   1 +
 .../src/main/resources/templates/batchnorm.st      |  14 +
 .../src/main/resources/templates/concat.st         |   1 +
 .../src/main/resources/templates/convolution.st    |   9 +
 .../src/main/resources/templates/deconvolution.st  |  10 +
 .../src/main/resources/templates/dropout.st        |   1 +
 .../src/main/resources/templates/fc.st             |   1 +
 .../src/main/resources/templates/flatten.st        |   1 +
 .../src/main/resources/templates/group.st          |   1 +
 .../src/main/resources/templates/imports.st        |   7 +
 .../src/main/resources/templates/init_params.st    |   7 +
 .../src/main/resources/templates/iterator.st       |  10 +
 .../src/main/resources/templates/logging.st        |  11 +
 .../src/main/resources/templates/lrn.st            |   1 +
 .../src/main/resources/templates/lrpolicy_exp.st   |   3 +
 .../src/main/resources/templates/lrpolicy_inv.st   |   3 +
 .../main/resources/templates/lrpolicy_multistep.st |   5 +
 .../src/main/resources/templates/lrpolicy_poly.st  |   3 +
 .../main/resources/templates/lrpolicy_sigmoid.st   |   3 +
 .../src/main/resources/templates/lrpolicy_step.st  |   4 +
 .../src/main/resources/templates/maxium.st         |   1 +
 .../main/resources/templates/metrics_classes.st    |  87 ++++
 .../src/main/resources/templates/mul.st            |   1 +
 .../src/main/resources/templates/opt_default.st    |  15 +
 .../src/main/resources/templates/opt_sgd.st        |  12 +
 .../main/resources/templates/param_initializer.st  |  12 +
 .../src/main/resources/templates/params_loader.st  |  13 +
 .../src/main/resources/templates/permute.st        |   1 +
 .../src/main/resources/templates/pooling.st        |  16 +
 .../src/main/resources/templates/power.st          |   1 +
 .../src/main/resources/templates/runner.st         |  57 +++
 .../src/main/resources/templates/softmaxoutput.st  |   3 +
 .../src/main/resources/templates/symbols.stg       |   7 +
 .../src/main/resources/templates/top_k_accuracy.st |   2 +
 .../src/main/resources/templates/var.st            |   1 +
 84 files changed, 4143 insertions(+)
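
The `lrpolicy_*.st` templates listed above correspond to Caffe's `lr_policy` values (exp, inv, multistep, poly, sigmoid, step). As a reference for checking a translated schedule, here is a small sketch of the learning-rate formulas Caffe's solver applies for each policy. It is written from Caffe's documented definitions, not from the templates, so it says nothing about the exact MXNet code the translator emits. Parameter names mirror the solver prototxt fields (`base_lr`, `gamma`, `power`, `stepsize`, `stepvalue`, `max_iter`); raising `ValueError` for an unrecognized policy is this sketch's own choice.

```python
import math


def caffe_learning_rate(policy, it, base_lr, gamma=None, power=None,
                        stepsize=None, stepvalues=(), max_iter=None):
    """Learning rate at iteration `it` for a Caffe lr_policy.

    Only the fields a given policy uses need to be supplied.
    """
    if policy == "fixed":
        return base_lr
    if policy == "step":
        # Multiply by gamma once every `stepsize` iterations.
        return base_lr * gamma ** (it // stepsize)
    if policy == "exp":
        return base_lr * gamma ** it
    if policy == "inv":
        return base_lr * (1 + gamma * it) ** (-power)
    if policy == "multistep":
        # Like "step", but the drops happen at the listed iterations.
        passed = sum(1 for value in stepvalues if it >= value)
        return base_lr * gamma ** passed
    if policy == "poly":
        return base_lr * (1 - float(it) / max_iter) ** power
    if policy == "sigmoid":
        return base_lr / (1 + math.exp(-gamma * (it - stepsize)))
    # Flag unknown policies instead of silently picking a schedule.
    raise ValueError("unknown lr_policy: %r" % policy)
```

Evaluating the function at a few iterations of a solver and comparing against the learning rate the translated MXNet script reports is one way to spot a mistranslated schedule.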

diff --git a/tools/caffe_translator/README.md b/tools/caffe_translator/README.md
new file mode 100644
index 0000000..1d5a77c
--- /dev/null
+++ b/tools/caffe_translator/README.md
@@ -0,0 +1,79 @@
+# Caffe Translator
+Caffe Translator is a migration tool that helps developers migrate their existing Caffe code to MXNet and continue further development using MXNet. Note that this is different from the Caffe to MXNet model converter which is available [here](https://github.com/apache/incubator-mxnet/tree/master/tools/caffe_converter).
+
+Caffe Translator takes the training/validation prototxt ([example](https://github.com/BVLC/caffe/blob/master/examples/mnist/lenet_train_test.prototxt)) and solver prototxt ([example](https://github.com/BVLC/caffe/blob/master/examples/mnist/lenet_solver.prototxt)) as input and produces MXNet Python code ([example](https://www.caffetranslator.org/examples/lenet/lenet_translated.py)) as output. The translated Python code uses MXNet Symbol and Module API to build the network, reads data from [...]
+
+### How to use
+
+#### Get the translator:
+Download the Caffe Translator from the Maven [repository](https://mvnrepository.com/artifact/org.caffetranslator/caffe-translator) or [build](build_from_source.md) it from source. A Java Runtime Environment (JRE) is required to run the translator.
+
+#### Translate code:
+To translate `train_val.prototxt` and `solver.prototxt` to MXNet Python code, run the following command:
+```
+java -jar caffe-translator-<version>.jar --training-prototxt <train_val_prototxt_path> \
+    --solver <solver_prototxt_path> \
+    --output-file <output_file_path>
+```
+Example:
+```
+java -jar caffe-translator-0.9.1.jar --training-prototxt lenet_train_test.prototxt \
+    --solver lenet_solver.prototxt \
+    --output-file translated_code.py
+```
+
+Here is the list of command line parameters accepted by the Caffe Translator:
+- *training-prototxt*: specifies the path to the training/validation prototxt to be translated.
+- *solver*: specifies the path to the solver prototxt to be translated.
+- *output-file*: specifies the file to write the translated output into.
+- *params-file* (optional): specifies the .caffemodel file to initialize parameters from.
+- *custom-data-layers* (optional): specifies a comma-separated list of the custom data layer types used in the prototxt. The translator will use [`CaffeDataIter`](https://mxnet.incubator.apache.org/how_to/caffe.html#use-io-caffedataiter) to translate these layers to MXNet.
+
+**Note:** The translated code uses [`CaffeDataIter`](https://mxnet.incubator.apache.org/how_to/caffe.html#use-io-caffedataiter) to read from LMDB files. `CaffeDataIter` requires the number of examples in the LMDB file to be specified as a parameter. You can provide this information before translation using a `#CaffeToMXNet` directive as shown below:
+
+```
+  data_param {
+    source: "data/mnist/mnist_train_lmdb"
+    #CaffeToMXNet num_examples: 60000
+    batch_size: 64
+    backend: LMDB
+  }
+```
+
+#### Run the translated code:
+
+The following prerequisites are required to run the translated code:
+1. Caffe with MXNet interface ([Why?](faq.md#why_caffe) [How to build?](https://github.com/apache/incubator-mxnet/tree/master/plugin/caffe#install-caffe-with-mxnet-interface))
+2. MXNet with Caffe plugin ([How to build?](https://github.com/apache/incubator-mxnet/tree/master/plugin/caffe#compile-with-caffe))
+3. The dataset in LMDB format.
+
+Once prerequisites are installed, the translated Python code can be run like any other Python code:
+
+Example:
+```
+python translated_code.py
+```
+
+### What layers are supported?
+
+Caffe Translator can currently translate the following layers:
+
+- Accuracy and Top-k
+- Batch Normalization
+- Concat
+- Convolution
+- Data<sup>*</sup>
+- Deconvolution
+- Eltwise
+- Inner Product (Fully Connected layer)
+- Flatten
+- Permute
+- Pooling
+- Power
+- Relu
+- Scale<sup>*</sup>
+- SoftmaxOutput
+
+<sup>*</sup> Uses [CaffePlugin](https://github.com/apache/incubator-mxnet/tree/master/plugin/caffe)
+
+If you want Caffe Translator to translate a layer that is not in the above list, please create an [issue](https://github.com/apache/incubator-mxnet/issues/new).
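
The `#CaffeToMXNet` directive described in the README note is an ordinary prototxt comment, so Caffe ignores it while the translator reads it. The sketch below illustrates the idea with a regular expression; the real translator parses prototxt with an ANTLR grammar (`CaffePrototxt.g4`), so this is only an approximation. Following the commit note that directives must not reach the iterator's prototxt, it also returns the text with directive lines removed. The names `extract_directives` and `DIRECTIVE` are hypothetical.

```python
import re

# A directive line looks like "#CaffeToMXNet num_examples: 60000".
DIRECTIVE = re.compile(r"^\s*#CaffeToMXNet\s+(\w+)\s*:\s*(\S+)\s*$")


def extract_directives(prototxt_text):
    """Return (directives, cleaned_text) for a block of prototxt.

    `directives` maps each directive key to its string value;
    `cleaned_text` is the input with directive lines dropped, so they
    are not forwarded to the data iterator's prototxt.
    """
    directives = {}
    kept_lines = []
    for line in prototxt_text.splitlines():
        match = DIRECTIVE.match(line)
        if match:
            directives[match.group(1)] = match.group(2)
        else:
            kept_lines.append(line)
    return directives, "\n".join(kept_lines)


example = """  data_param {
    source: "data/mnist/mnist_train_lmdb"
    #CaffeToMXNet num_examples: 60000
    batch_size: 64
    backend: LMDB
  }"""
directives, cleaned = extract_directives(example)
num_examples = int(directives["num_examples"])
```

Values come back as strings because the directive syntax carries no type information; the caller converts `num_examples` to an integer before passing it on.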
diff --git a/tools/caffe_translator/build.gradle b/tools/caffe_translator/build.gradle
new file mode 100644
index 0000000..4206767
--- /dev/null
+++ b/tools/caffe_translator/build.gradle
@@ -0,0 +1,152 @@
+import org.gradle.api.artifacts.maven.MavenDeployment
+
+apply plugin: 'com.github.johnrengelman.shadow'
+apply plugin: 'java'
+
+apply plugin: 'antlr'
+apply plugin: 'application'
+
+apply plugin: 'maven'
+apply plugin: 'signing'
+
+group 'org.caffetranslator'
+version '0.9.1'
+
+def isReleaseBuild
+def repositoryUrl
+
+if(hasProperty("release")) {
+    isReleaseBuild = true
+    repositoryUrl = stagingRepositoryUrl
+} else if(hasProperty("CI")) {
+    repositoryUrl = snapshotRepositoryUrl
+    version += "-SNAPSHOT"
+}
+
+buildscript {
+    repositories {
+        jcenter()
+    }
+    dependencies {
+        classpath 'com.github.jengelman.gradle.plugins:shadow:2.0.1'
+    }
+}
+
+sourceCompatibility = 1.8
+
+repositories {
+    mavenCentral()
+}
+
+dependencies {
+    antlr "org.antlr:antlr4:$antlrVersion"
+    compile group: 'commons-cli', name: 'commons-cli', version: '1.4'
+    compileOnly 'org.projectlombok:lombok:1.16.18'
+    testCompile group: 'junit', name: 'junit', version: '4.12'
+}
+
+generateGrammarSource {
+    arguments += ['-visitor']
+}
+
+jar {
+    baseName = 'caffe-translator'
+    appendix = 'slim'
+    version = version
+    manifest {
+        attributes 'Main-Class': 'io.mxnet.caffetranslator.Launcher'
+    }
+}
+
+task javadocJar(type: Jar) {
+    classifier = 'javadoc'
+    from javadoc
+}
+
+task sourcesJar(type: Jar) {
+    classifier = 'sources'
+    from sourceSets.main.allSource
+}
+
+shadowJar {
+    baseName = 'caffe-translator'
+    classifier = ''
+    version = version
+}
+
+configurations {
+    releaseJars
+    ascSignatures
+}
+
+artifacts {
+    releaseJars shadowJar
+    releaseJars sourcesJar
+    releaseJars javadocJar
+}
+
+if(isReleaseBuild) {
+    signing {
+        sign configurations.releaseJars
+    }
+} else {
+    task signReleaseJars {
+        //no-op
+    }
+}
+
+uploadShadow {
+    repositories {
+        mavenDeployer {
+            beforeDeployment { MavenDeployment deployment ->
+                if(isReleaseBuild) {
+                    signing.signPom(deployment)
+                }
+                configurations.releaseJars.artifacts.each { ra ->
+                    def ascfile = file(ra.file.path + '.asc')
+                    def ascArtifact = project.artifacts.add('ascSignatures', ascfile) {
+                        classifier = ra.classifier
+                        extension = ra.extension + '.asc'
+                        type = ra.type + '.asc'
+                    }
+                    deployment.addArtifact(ra)
+                    deployment.addArtifact(ascArtifact)
+                }
+            }
+
+            repository(url: repositoryUrl) {
+                authentication(userName: ossrhUsername, password: ossrhPassword)
+            }
+
+            pom.project {
+                name 'Caffe Translator'
+                packaging 'jar'
+                description 'Translate Caffe code to MXNet Python code'
+                url 'http://caffetranslator.org'
+
+                licenses {
+                    license {
+                        name 'The Apache Software License, Version 2.0'
+                        url 'http://www.apache.org/licenses/LICENSE-2.0.txt'
+                        distribution 'repo'
+                    }
+                }
+
+                developers {
+                    developer {
+                        name 'Indhu Bharathi'
+                        email 'indhub@apache.org'
+                    }
+                }
+
+                scm {
+                    connection 'scm:git:git://github.com/apache/incubator-mxnet.git'
+                    developerConnection 'scm:git:git@github.com:apache/incubator-mxnet.git'
+                    url 'https://github.com/apache/incubator-mxnet.git'
+                }
+            }
+        }
+    }
+}
+
+mainClassName = "io.mxnet.caffetranslator.Launcher"
diff --git a/tools/caffe_translator/build_from_source.md b/tools/caffe_translator/build_from_source.md
new file mode 100644
index 0000000..81480fa
--- /dev/null
+++ b/tools/caffe_translator/build_from_source.md
@@ -0,0 +1,21 @@
+### Build Caffe Translator from source
+
+#### Prerequisites:
+- JDK 8 (the build sets `sourceCompatibility = 1.8`)
+
+#### Instructions to build
+
+Step 1: Clone the code:
+```
+git clone https://github.com/apache/incubator-mxnet.git mxnet
+```
+Step 2: Change to the Caffe Translator directory
+```
+cd mxnet/tools/caffe_translator/
+```
+Step 3: Build
+```
+./gradlew build
+```
+
+Caffe Translator will be built at `build/libs/caffe-translator-<version>.jar`
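
With the jar built, translation runs can also be scripted. The helper below is a hypothetical sketch that assembles the command line described in the README. The `--training-prototxt`, `--solver` and `--output-file` flags appear in the README's usage example; spelling the optional parameters as `--params-file` and `--custom-data-layers` is an assumption based on the README's parameter list. The function names are illustrative only.

```python
import os


def built_jar_path(version, build_dir="build"):
    """Location of the jar as given in build_from_source.md."""
    return os.path.join(build_dir, "libs", "caffe-translator-%s.jar" % version)


def translator_command(jar_path, training_prototxt, solver, output_file,
                       params_file=None, custom_data_layers=None):
    """Return the argument list for one Caffe Translator run."""
    cmd = ["java", "-jar", jar_path,
           "--training-prototxt", training_prototxt,
           "--solver", solver,
           "--output-file", output_file]
    # Optional flags: spelling assumed from the README's parameter list.
    if params_file:
        cmd += ["--params-file", params_file]
    if custom_data_layers:
        cmd += ["--custom-data-layers", ",".join(custom_data_layers)]
    return cmd


cmd = translator_command(built_jar_path("0.9.1"), "lenet_train_test.prototxt",
                         "lenet_solver.prototxt", "translated_code.py")
```

Passing the returned list to `subprocess.check_call` runs the translator without going through a shell, so paths containing spaces need no extra quoting.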
diff --git a/tools/caffe_translator/faq.md b/tools/caffe_translator/faq.md
new file mode 100644
index 0000000..81cdfb9
--- /dev/null
+++ b/tools/caffe_translator/faq.md
@@ -0,0 +1,17 @@
+### Frequently asked questions
+
+[**Why is Caffe required to run the translated code?**](#why_caffe)
+
+There are a couple of reasons why Caffe is required to run the translated code:
+
+1. The translator does not convert the Caffe data layer to native MXNet code because MXNet cannot read from LMDB files. Instead, the translator generates code that uses [`CaffeDataIter`](https://mxnet.incubator.apache.org/how_to/caffe.html#use-io-caffedataiter), which can read LMDB files. `CaffeDataIter` needs Caffe to run.
+
+2. If the Caffe code to be translated uses custom layers, or layers that don't have equivalent MXNet layers, the translator will generate code that will use [CaffeOp](https://mxnet.incubator.apache.org/how_to/caffe.html#use-sym-caffeop). CaffeOp needs Caffe to run.
+
+[**What version of Caffe prototxt can the translator translate?**](#what_version_of_prototxt)
+
+Caffe Translator supports the `proto2` syntax.
+
+[**Can the translator translate Caffe 2 code?**](#caffe_2_support)
+
+No. At the moment, only Caffe is supported.
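
The two reasons in the first FAQ answer come down to a per-layer decision: layers with an MXNet equivalent become native symbols, data layers become `CaffeDataIter`, and everything else becomes `CaffeOp`. The sketch below models that decision as a lookup table. The table is illustrative and incomplete, it is not the translator's Java `SymbolGeneratorFactory`, and the names `plan_layer` and `needs_caffe` are hypothetical. Scale is deliberately absent because the README lists it as using the Caffe plugin.

```python
# Hypothetical, incomplete table: Caffe layer type -> MXNet operator.
NATIVE_OPS = {
    "Convolution": "Convolution",
    "Deconvolution": "Deconvolution",
    "InnerProduct": "FullyConnected",
    "Pooling": "Pooling",
    "ReLU": "Activation",
    "BatchNorm": "BatchNorm",
    "Concat": "Concat",
    "Dropout": "Dropout",
    "Flatten": "Flatten",
    "SoftmaxWithLoss": "SoftmaxOutput",
}

DATA_LAYER_TYPES = {"Data"}


def plan_layer(layer_type, custom_data_layers=()):
    """Return (kind, target) describing how one layer would be translated."""
    if layer_type in DATA_LAYER_TYPES or layer_type in custom_data_layers:
        return ("iterator", "CaffeDataIter")
    if layer_type in NATIVE_OPS:
        return ("native", NATIVE_OPS[layer_type])
    # No MXNet equivalent in the table: wrap the Caffe layer itself.
    return ("plugin", "CaffeOp")


def needs_caffe(layer_types, custom_data_layers=()):
    """True if any layer would require Caffe at run time."""
    return any(plan_layer(t, custom_data_layers)[0] != "native"
               for t in layer_types)
```

Because every network that reads LMDB data contains a data layer, `needs_caffe` is true for any complete training prototxt, which is the practical upshot of the FAQ answer.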
diff --git a/tools/caffe_translator/gradle.properties b/tools/caffe_translator/gradle.properties
new file mode 100644
index 0000000..f115eaa
--- /dev/null
+++ b/tools/caffe_translator/gradle.properties
@@ -0,0 +1,12 @@
+antlrVersion=4.7
+
+signing.keyId=<key-id>
+signing.password=<key-password>
+signing.secretKeyRingFile=<path-to-key-ring-file>
+
+snapshotRepositoryUrl=https://oss.sonatype.org/content/repositories/snapshots
+stagingRepositoryUrl=https://oss.sonatype.org/service/local/staging/deploy/maven2
+
+ossrhUsername=<ossrh-username>
+ossrhPassword=<ossrh-password>
+
diff --git a/tools/caffe_translator/gradle/wrapper/gradle-wrapper.jar b/tools/caffe_translator/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000..6b6ea3a
Binary files /dev/null and b/tools/caffe_translator/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/tools/caffe_translator/gradle/wrapper/gradle-wrapper.properties b/tools/caffe_translator/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000..0e680f3
--- /dev/null
+++ b/tools/caffe_translator/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,5 @@
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-4.3.1-bin.zip
diff --git a/tools/caffe_translator/gradlew b/tools/caffe_translator/gradlew
new file mode 100755
index 0000000..cccdd3d
--- /dev/null
+++ b/tools/caffe_translator/gradlew
@@ -0,0 +1,172 @@
+#!/usr/bin/env sh
+
+##############################################################################
+##
+##  Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+    ls=`ls -ld "$PRG"`
+    link=`expr "$ls" : '.*-> \(.*\)$'`
+    if expr "$link" : '/.*' > /dev/null; then
+        PRG="$link"
+    else
+        PRG=`dirname "$PRG"`"/$link"
+    fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+    echo "$*"
+}
+
+die () {
+    echo
+    echo "$*"
+    echo
+    exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+  CYGWIN* )
+    cygwin=true
+    ;;
+  Darwin* )
+    darwin=true
+    ;;
+  MINGW* )
+    msys=true
+    ;;
+  NONSTOP* )
+    nonstop=true
+    ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+        # IBM's JDK on AIX uses strange locations for the executables
+        JAVACMD="$JAVA_HOME/jre/sh/java"
+    else
+        JAVACMD="$JAVA_HOME/bin/java"
+    fi
+    if [ ! -x "$JAVACMD" ] ; then
+        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+    fi
+else
+    JAVACMD="java"
+    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+    MAX_FD_LIMIT=`ulimit -H -n`
+    if [ $? -eq 0 ] ; then
+        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+            MAX_FD="$MAX_FD_LIMIT"
+        fi
+        ulimit -n $MAX_FD
+        if [ $? -ne 0 ] ; then
+            warn "Could not set maximum file descriptor limit: $MAX_FD"
+        fi
+    else
+        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+    fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+    JAVACMD=`cygpath --unix "$JAVACMD"`
+
+    # We build the pattern for arguments to be converted via cygpath
+    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+    SEP=""
+    for dir in $ROOTDIRSRAW ; do
+        ROOTDIRS="$ROOTDIRS$SEP$dir"
+        SEP="|"
+    done
+    OURCYGPATTERN="(^($ROOTDIRS))"
+    # Add a user-defined pattern to the cygpath arguments
+    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+    fi
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    i=0
+    for arg in "$@" ; do
+        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+        CHECK2=`echo "$arg"|egrep -c "^-"`                                 ### Determine if an option
+
+        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition
+            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+        else
+            eval `echo args$i`="\"$arg\""
+        fi
+        i=$((i+1))
+    done
+    case $i in
+        (0) set -- ;;
+        (1) set -- "$args0" ;;
+        (2) set -- "$args0" "$args1" ;;
+        (3) set -- "$args0" "$args1" "$args2" ;;
+        (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+        (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+        (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+        (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+        (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+        (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+    esac
+fi
+
+# Escape application args
+save () {
+    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+    echo " "
+}
+APP_ARGS=$(save "$@")
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
+if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
+  cd "$(dirname "$0")"
+fi
+
+exec "$JAVACMD" "$@"
diff --git a/tools/caffe_translator/gradlew.bat b/tools/caffe_translator/gradlew.bat
new file mode 100644
index 0000000..a1c49a3
--- /dev/null
+++ b/tools/caffe_translator/gradlew.bat
@@ -0,0 +1,101 @@
+@rem Licensed to the Apache Software Foundation (ASF) under one
+@rem or more contributor license agreements.  See the NOTICE file
+@rem distributed with this work for additional information
+@rem regarding copyright ownership.  The ASF licenses this file
+@rem to you under the Apache License, Version 2.0 (the
+@rem "License"); you may not use this file except in compliance
+@rem with the License.  You may obtain a copy of the License at
+@rem
+@rem   http://www.apache.org/licenses/LICENSE-2.0
+@rem
+@rem Unless required by applicable law or agreed to in writing,
+@rem software distributed under the License is distributed on an
+@rem "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+@rem KIND, either express or implied.  See the License for the
+@rem specific language governing permissions and limitations
+@rem under the License.
+
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem  Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/tools/caffe_translator/scripts/convert_caffe_model.py b/tools/caffe_translator/scripts/convert_caffe_model.py
new file mode 100644
index 0000000..d7f13c4
--- /dev/null
+++ b/tools/caffe_translator/scripts/convert_caffe_model.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Script to convert a Caffe .caffemodel file to an MXNet .params file"""
+from __future__ import print_function
+import argparse
+import mxnet as mx
+
+import caffe
+from caffe.proto import caffe_pb2
+
+class CaffeModelConverter(object):
+    """Converts a Caffe .caffemodel file to an MXNet .params file"""
+    def __init__(self):
+        self.dict_param = {}
+        self.layers = None
+
+    def add_param(self, param_name, layer_index, blob_index):
+        """Add a param to the .params file"""
+        blobs = self.layers[layer_index].blobs
+        self.dict_param[param_name] = mx.nd.array(caffe.io.blobproto_to_array(blobs[blob_index]))
+
+    def add_arg_param(self, param_name, layer_index, blob_index):
+        """Add an arg param to the .params file. Example: weights of a fully connected layer."""
+        self.add_param('arg:%s' % param_name, layer_index, blob_index)
+
+    def add_aux_param(self, param_name, layer_index, blob_index):
+        """Add an aux param to the .params file. Example: moving_mean in a BatchNorm layer."""
+        self.add_param('aux:%s' % param_name, layer_index, blob_index)
+
+    def add_optional_arg_param(self, param_name, layer_index, blob_index):
+        """Add an arg param. If there is no such param in .caffemodel file, silently ignore it."""
+        blobs = self.layers[layer_index].blobs
+        if blob_index < len(blobs):
+            self.add_arg_param(param_name, layer_index, blob_index)
+
+    def convert(self, caffemodel_path, outmodel_path):
+        """Convert a Caffe .caffemodel file to MXNet .params file"""
+        net_param = caffe_pb2.NetParameter()
+        with open(caffemodel_path, 'rb') as caffe_model_file:
+            net_param.ParseFromString(caffe_model_file.read())
+
+        layers = net_param.layer
+        self.layers = layers
+
+        for idx, layer in enumerate(layers):
+            layer_name = str(layer.name)
+
+            if layer.blobs:
+
+                # If this is a layer that has only weight and bias as parameter
+                if layer.type == 'Convolution' or layer.type == 'InnerProduct' \
+                        or layer.type == 'Deconvolution':
+
+                    # Add weight and bias to the dictionary
+                    self.add_arg_param('%s_weight' % layer_name, layer_index=idx, blob_index=0)
+                    self.add_optional_arg_param('%s_bias' % layer_name, layer_index=idx,
+                                                blob_index=1)
+
+                elif layer.type == 'BatchNorm':
+
+                    gamma_param_name = '%s_gamma' % layer_name
+                    beta_param_name = '%s_beta' % layer_name
+
+                    next_layer = layers[idx + 1] if idx + 1 < len(layers) else None
+
+                    if next_layer is not None and next_layer.type == 'Scale':
+                        # If next layer is scale layer, get gamma and beta from there
+                        self.add_arg_param(gamma_param_name, layer_index=idx+1, blob_index=0)
+                        self.add_arg_param(beta_param_name, layer_index=idx+1, blob_index=1)
+
+                    mean_param_name = '%s_moving_mean' % layer_name
+                    var_param_name = '%s_moving_var' % layer_name
+
+                    self.add_aux_param(mean_param_name, layer_index=idx, blob_index=0)
+                    self.add_aux_param(var_param_name, layer_index=idx, blob_index=1)
+
+                elif layer.type == 'Scale':
+
+                    prev_layer = layers[idx - 1] if idx > 0 else None
+
+                    if prev_layer is not None and prev_layer.type == 'BatchNorm':
+                        continue
+                    else:
+                        # Use the naming convention used by CaffeOp
+                        self.add_arg_param('%s_0_weight' % layer_name, layer_index=idx,
+                                           blob_index=0)
+                        self.add_optional_arg_param('%s_1_bias' % layer_name,
+                                                    layer_index=idx, blob_index=1)
+
+        mx.nd.save(outmodel_path, self.dict_param)
+
+def main():
+    """Read .caffemodel path and .params path as input from command line
+    and use CaffeModelConverter to do the conversion"""
+    parser = argparse.ArgumentParser(description='.caffemodel to MXNet .params converter.')
+    parser.add_argument('caffemodel', help='Path to the .caffemodel file to convert.')
+    parser.add_argument('output_file_name', help='Name of the output .params file.')
+
+    args = parser.parse_args()
+
+    converter = CaffeModelConverter()
+    converter.convert(args.caffemodel, args.output_file_name)
+
+if __name__ == '__main__':
+    main()
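The naming rules that `convert()` applies can be checked without Caffe or MXNet installed. The sketch below is a hypothetical, dependency-free helper (`mxnet_param_names` is not part of the script) that takes `(name, type, num_blobs)` tuples in place of parsed Caffe layers and returns the keys that would end up in the `.params` file: weight/bias for Convolution, InnerProduct and Deconvolution; gamma/beta from a following Scale layer plus moving mean/var as aux params for BatchNorm; and CaffeOp-style `_0_weight`/`_1_bias` names for a standalone Scale layer.

```python
from typing import List, Tuple

# Hypothetical stand-in for a parsed Caffe layer: (name, type, number of blobs).
LayerInfo = Tuple[str, str, int]

WEIGHT_BIAS_TYPES = ('Convolution', 'InnerProduct', 'Deconvolution')


def mxnet_param_names(layers: List[LayerInfo]) -> List[str]:
    """Return the .params keys that convert_caffe_model.py would emit (sketch)."""
    keys = []
    for idx, (name, ltype, num_blobs) in enumerate(layers):
        if num_blobs == 0:
            continue  # layers without learned blobs contribute nothing
        if ltype in WEIGHT_BIAS_TYPES:
            keys.append('arg:%s_weight' % name)
            if num_blobs > 1:  # the bias blob is optional (e.g. bias_term: false)
                keys.append('arg:%s_bias' % name)
        elif ltype == 'BatchNorm':
            nxt = layers[idx + 1] if idx + 1 < len(layers) else None
            if nxt is not None and nxt[1] == 'Scale':
                # In Caffe, gamma/beta live in the Scale layer that follows BatchNorm
                keys.append('arg:%s_gamma' % name)
                keys.append('arg:%s_beta' % name)
            keys.append('aux:%s_moving_mean' % name)
            keys.append('aux:%s_moving_var' % name)
        elif ltype == 'Scale':
            prev = layers[idx - 1] if idx > 0 else None
            if prev is not None and prev[1] == 'BatchNorm':
                continue  # already folded into the preceding BatchNorm entry
            keys.append('arg:%s_0_weight' % name)
            if num_blobs > 1:
                keys.append('arg:%s_1_bias' % name)
    return keys


example = [
    ('conv1', 'Convolution', 2),
    ('bn1', 'BatchNorm', 3),
    ('scale1', 'Scale', 2),
    ('relu1', 'ReLU', 0),
    ('fc1', 'InnerProduct', 1),
]
print(mxnet_param_names(example))
```

The `arg:`/`aux:` prefixes follow the convention MXNet checkpoints use to separate `arg_params` from `aux_params`, which is why moving statistics are written as `aux:` entries.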
diff --git a/tools/caffe_translator/settings.gradle b/tools/caffe_translator/settings.gradle
new file mode 100644
index 0000000..c7259d1
--- /dev/null
+++ b/tools/caffe_translator/settings.gradle
@@ -0,0 +1,2 @@
+rootProject.name = 'caffetranslator'
+
diff --git a/tools/caffe_translator/src/main/antlr/io/mxnet/caffetranslator/CaffePrototxt.g4 b/tools/caffe_translator/src/main/antlr/io/mxnet/caffetranslator/CaffePrototxt.g4
new file mode 100644
index 0000000..1009382
--- /dev/null
+++ b/tools/caffe_translator/src/main/antlr/io/mxnet/caffetranslator/CaffePrototxt.g4
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file CaffePrototxt.g4
+ * \brief Grammar to parse Caffe prototxt
+ */
+
+grammar CaffePrototxt;
+
+@header {
+package io.mxnet.caffetranslator;
+}
+
+
+prototxt: name layer+;
+
+solver: pair+;
+
+name: ID COLON STRING;
+
+layer: ID object;
+
+pair: ID COLON? value;
+
+value: object                   #valueObject
+     | (STRING | NUMBER | ID)   #valueLeaf
+     ;
+
+object: LPAREN pair+ RPAREN;
+
+LPAREN: '{';
+
+RPAREN: '}';
+
+COLON: ':';
+
+NUMBER : '-'? ('.' DIGIT+ | DIGIT+ ('.' DIGIT*)? ) Exponent?;
+fragment
+DIGIT : [0-9] ;
+fragment
+Exponent : ('e'|'E') ('+'|'-')? ('0'..'9')+ ;
+
+ID: LETTER (LETTER|DIGIT)*;
+
+fragment
+LETTER      :   [a-zA-Z\u0080-\u00FF_] ;
+
+STRING      :   '"' ('\\"'|.)*? '"'
+            |   '\'' ('\\\''|.)*? '\'' ;
+
+WS  :   [ \t]+ -> channel(HIDDEN) ;
+
+NL  :   [\n\r]+ -> channel(HIDDEN) ;
+
+COMMENT :  '#' ~( '\r' | '\n' )* {!getText().startsWith("#CaffeToMXNet")}? -> skip;
+
+CAFFE2MXNET: '#CaffeToMXNet' -> skip;
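The lexer rules above can be approximated with Python's `re` module to see which tokens reach the parser. This is a hypothetical sketch (`TOKEN_SPEC` and `tokenize` are illustrative names, not part of the translator); it relies on ordered alternation rather than ANTLR's longest-match rule, which should give the same tokens for typical prototxt input. The behavior worth noticing is the `#CaffeToMXNet` marker: as the rules read, Caffe sees the whole line as a comment, but only the marker itself is skipped here, so `num_examples: 60000` reaches the parser as an ordinary pair.

```python
import re

# Ordered token patterns approximating the lexer rules in CaffePrototxt.g4.
# The marker pattern must come before COMMENT so it wins at a '#'.
TOKEN_SPEC = [
    ('CAFFE2MXNET', r'#CaffeToMXNet'),          # skipped; rest of the line is lexed
    ('COMMENT', r'#[^\r\n]*'),                  # ordinary comment: skipped
    ('STRING', r'"(?:\\"|[^"])*"' + '|' + r"'(?:\\'|[^'])*'"),
    ('NUMBER', r'-?(?:\.\d+|\d+(?:\.\d*)?)(?:[eE][+-]?\d+)?'),
    ('ID', r'[a-zA-Z\u0080-\u00FF_][a-zA-Z0-9\u0080-\u00FF_]*'),
    ('LPAREN', r'\{'),
    ('RPAREN', r'\}'),
    ('COLON', r':'),
    ('WS', r'[ \t\r\n]+'),                      # WS/NL: hidden channel, ignored by the parser
]
MASTER = re.compile('|'.join('(?P<%s>%s)' % spec for spec in TOKEN_SPEC))
SKIPPED = {'CAFFE2MXNET', 'COMMENT', 'WS'}


def tokenize(text):
    """Return (token_type, text) pairs for everything the parser would see."""
    tokens = []
    pos = 0
    while pos < len(text):
        match = MASTER.match(text, pos)
        if match is None:
            raise ValueError('unexpected character %r at offset %d' % (text[pos], pos))
        if match.lastgroup not in SKIPPED:
            tokens.append((match.lastgroup, match.group()))
        pos = match.end()
    return tokens


sample = ('name: "LeNet"\n'
          'layer {\n'
          '  name: "mnist"  # a comment\n'
          '  data_param { batch_size: 64 }\n'
          '  #CaffeToMXNet num_examples: 60000\n'
          '}\n')
for token in tokenize(sample):
    print(token)
```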
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Config.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Config.java
new file mode 100644
index 0000000..006e133
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Config.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Config.java
+ * \brief Helper class to store config
+ */
+
+package io.mxnet.caffetranslator;
+
+import java.util.List;
+import java.util.Vector;
+
+public class Config {
+
+    private static final Config instance = new Config();
+
+    public static Config getInstance() {
+        return instance;
+    }
+
+    private Config() {
+        if (instance != null) {
+            throw new IllegalStateException("Already instantiated");
+        }
+
+        customDataLayers = new Vector<String>();
+    }
+
+    public List<String> getCustomDataLayers() {
+        return customDataLayers;
+    }
+
+    public void addCustomDataLayer(String name) {
+        customDataLayers.add(name);
+    }
+
+    private Vector<String> customDataLayers;
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java
new file mode 100644
index 0000000..90ed9d2
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java
@@ -0,0 +1,460 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Converter.java
+ * \brief Convert Caffe prototxt to MXNet Python code
+ */
+
+package io.mxnet.caffetranslator;
+
+import io.mxnet.caffetranslator.generators.*;
+import lombok.Setter;
+import org.antlr.v4.runtime.CharStream;
+import org.antlr.v4.runtime.CharStreams;
+import org.antlr.v4.runtime.CommonTokenStream;
+import org.stringtemplate.v4.ST;
+import org.stringtemplate.v4.STGroup;
+import org.stringtemplate.v4.STRawGroupDir;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class Converter {
+
+    private final String trainPrototxt, solverPrototxt;
+    private final MLModel mlModel;
+    private final STGroup stGroup;
+    private final SymbolGeneratorFactory generators;
+    private final String NL;
+    private final GenerationHelper gh;
+    @Setter
+
+    private String paramsFilePath;
+    private Solver solver;
+
+    Converter(String trainPrototxt, String solverPrototxt) {
+        this.trainPrototxt = trainPrototxt;
+        this.solverPrototxt = solverPrototxt;
+        this.mlModel = new MLModel();
+        this.stGroup = new STRawGroupDir("templates");
+        this.generators = SymbolGeneratorFactory.getInstance();
+        NL = System.getProperty("line.separator");
+        gh = new GenerationHelper();
+        addGenerators();
+    }
+
+    private void addGenerators() {
+        generators.addGenerator("Convolution", new ConvolutionGenerator());
+        generators.addGenerator("Deconvolution", new DeconvolutionGenerator());
+        generators.addGenerator("Pooling", new PoolingGenerator());
+        generators.addGenerator("InnerProduct", new FCGenerator());
+        generators.addGenerator("ReLU", new ReluGenerator());
+        generators.addGenerator("SoftmaxWithLoss", new SoftmaxOutputGenerator());
+        generators.addGenerator("PluginIntLayerGenerator", new PluginIntLayerGenerator());
+        generators.addGenerator("CaffePluginLossLayer", new PluginLossGenerator());
+        generators.addGenerator("Permute", new PermuteGenerator());
+        generators.addGenerator("Concat", new ConcatGenerator());
+        generators.addGenerator("BatchNorm", new BatchNormGenerator());
+        generators.addGenerator("Power", new PowerGenerator());
+        generators.addGenerator("Eltwise", new EltwiseGenerator());
+        generators.addGenerator("Flatten", new FlattenGenerator());
+        generators.addGenerator("Dropout", new DropoutGenerator());
+        generators.addGenerator("Scale", new ScaleGenerator());
+    }
+
+    public boolean parseTrainingPrototxt() {
+
+        CharStream cs = null;
+        // try-with-resources closes the file once the CharStream has been read
+        try (FileInputStream fis = new FileInputStream(new File(trainPrototxt))) {
+            cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
+        } catch (IOException e) {
+            System.err.println("Unable to read prototxt: " + trainPrototxt);
+            return false;
+        }
+
+        CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
+
+        CommonTokenStream tokens = new CommonTokenStream(lexer);
+        CaffePrototxtParser parser = new CaffePrototxtParser(tokens);
+
+        CreateModelListener modelCreator = new CreateModelListener(parser, mlModel);
+        parser.addParseListener(modelCreator);
+        parser.prototxt();
+
+        return true;
+    }
+
+    public boolean parseSolverPrototxt() {
+        solver = new Solver(solverPrototxt);
+        return solver.parsePrototxt();
+    }
+
+    public String generateMXNetCode() {
+        if (!parseTrainingPrototxt()) {
+            return "";
+        }
+
+        if (!parseSolverPrototxt()) {
+            return "";
+        }
+
+        StringBuilder code = new StringBuilder();
+
+        code.append(generateImports());
+        code.append(System.lineSeparator());
+
+        code.append(generateLogger());
+        code.append(System.lineSeparator());
+
+        code.append(generateParamInitializer());
+        code.append(System.lineSeparator());
+
+        code.append(generateMetricsClasses());
+        code.append(System.lineSeparator());
+
+        if (paramsFilePath != null) {
+            code.append(generateParamsLoader());
+            code.append(System.lineSeparator());
+        }
+
+        // Convert data layers
+        code.append(generateIterators());
+
+        // Generate variables for data and label
+        code.append(generateInputVars());
+
+        // Convert non data layers
+        List<Layer> layers = mlModel.getNonDataLayers();
+
+        for (int layerIndex = 0; layerIndex < layers.size(); ) {
+            Layer layer = layers.get(layerIndex);
+            SymbolGenerator generator = generators.getGenerator(layer.getType());
+
+            // If the translator cannot translate this layer to an MXNet layer,
+            // use CaffeOp or CaffeLoss instead.
+            if (generator == null) {
+                if (layer.getType().toLowerCase().endsWith("loss")) {
+                    generator = generators.getGenerator("CaffePluginLossLayer");
+                } else {
+                    generator = generators.getGenerator("PluginIntLayerGenerator");
+                }
+            }
+
+            GeneratorOutput out = generator.generate(layer, mlModel);
+            String segment = out.code;
+            code.append(segment);
+            code.append(NL);
+
+            layerIndex += out.numLayersTranslated;
+        }
+
+        String loss = getLoss(mlModel, code);
+
+        String evalMetric = generateValidationMetrics(mlModel);
+        code.append(evalMetric);
+
+        String runner = generateRunner(loss);
+        code.append(runner);
+
+        return code.toString();
+    }
+
+    private String generateLogger() {
+        ST st = gh.getTemplate("logging");
+        st.add("name", mlModel.getName());
+        return st.render();
+    }
+
+    private String generateRunner(String loss) {
+        ST st = gh.getTemplate("runner");
+        st.add("max_iter", solver.getProperty("max_iter"));
+        st.add("stepsize", solver.getProperty("stepsize"));
+        st.add("snapshot", solver.getProperty("snapshot"));
+        st.add("test_interval", solver.getProperty("test_interval"));
+        st.add("test_iter", solver.getProperty("test_iter"));
+        st.add("snapshot_prefix", solver.getProperty("snapshot_prefix"));
+
+        st.add("train_data_itr", getIteratorName("TRAIN"));
+        st.add("test_data_itr", getIteratorName("TEST"));
+
+        String context = solver.getProperty("solver_mode", "cpu").toLowerCase();
+        context = String.format("mx.%s()", context);
+        st.add("ctx", context);
+
+        st.add("loss", loss);
+
+        st.add("data_names", getDataNames());
+        st.add("label_names", getLabelNames());
+
+        st.add("init_params", generateInitializer());
+
+        st.add("init_optimizer", generateOptimizer());
+        st.add("gamma", solver.getProperty("gamma"));
+        st.add("power", solver.getProperty("power"));
+        st.add("lr_update", generateLRUpdate());
+
+        return st.render();
+    }
+
+    private String generateParamInitializer() {
+        return gh.getTemplate("param_initializer").render();
+    }
+
+    private String generateMetricsClasses() {
+        ST st = gh.getTemplate("metrics_classes");
+
+        String display = solver.getProperty("display");
+        String average_loss = solver.getProperty("average_loss");
+
+        if (display != null) {
+            st.add("display", display);
+        }
+
+        if (average_loss != null) {
+            st.add("average_loss", average_loss);
+        }
+
+        return st.render();
+    }
+
+    private String generateParamsLoader() {
+        return gh.getTemplate("params_loader").render();
+    }
+
+    private String getLoss(MLModel model, StringBuilder out) {
+        List<String> losses = new ArrayList<>();
+        for (Layer layer : model.getLayerList()) {
+            if (layer.getType().toLowerCase().endsWith("loss")) {
+                losses.add(gh.getVarname(layer.getTop()));
+            }
+        }
+
+        if (losses.size() == 1) {
+            return losses.get(0);
+        } else if (losses.size() > 1) {
+            String loss_var = "combined_loss";
+            ST st = gh.getTemplate("group");
+            st.add("var", loss_var);
+            st.add("symbols", losses);
+            out.append(st.render());
+            return loss_var;
+        } else {
+            System.err.println("No loss found");
+            return "unknown_loss";
+        }
+    }
+
+    private String generateLRUpdate() {
+        String code;
+        String lrPolicy = solver.getProperty("lr_policy", "fixed").toLowerCase();
+        ST st;
+        switch (lrPolicy) {
+            case "fixed":
+                // lr stays fixed. No update needed
+                code = "";
+                break;
+            case "multistep":
+                st = gh.getTemplate("lrpolicy_multistep");
+                st.add("steps", solver.getProperties("stepvalue"));
+                code = st.render();
+                break;
+            case "step":
+            case "exp":
+            case "inv":
+            case "poly":
+            case "sigmoid":
+                st = gh.getTemplate("lrpolicy_" + lrPolicy);
+                code = st.render();
+                break;
+            default:
+                String message = "Unknown lr_policy: " + lrPolicy;
+                System.err.println(message);
+                code = "# " + message + System.lineSeparator();
+                break;
+        }
+        return Utils.indent(code, 2, true, 4);
+    }
+
+    private String generateValidationMetrics(MLModel mlModel) {
+        return new AccuracyMetricsGenerator().generate(mlModel);
+    }
+
+    private String generateOptimizer() {
+        String caffeOptimizer = solver.getProperty("type", "sgd").toLowerCase();
+        ST st;
+
+        String lr = solver.getProperty("base_lr");
+        String momentum = solver.getProperty("momentum", "0.9");
+        String wd = solver.getProperty("weight_decay", "0.0005");
+
+        switch (caffeOptimizer) {
+            case "adadelta":
+                st = gh.getTemplate("opt_default");
+                st.add("opt_name", "AdaDelta");
+                st.add("epsilon", solver.getProperty("delta"));
+                break;
+            case "adagrad":
+                st = gh.getTemplate("opt_default");
+                st.add("opt_name", "AdaGrad");
+                break;
+            case "adam":
+                st = gh.getTemplate("opt_default");
+                st.add("opt_name", "Adam");
+                break;
+            case "nesterov":
+                st = gh.getTemplate("opt_sgd");
+                st.add("opt_name", "NAG");
+                st.add("momentum", momentum);
+                break;
+            case "rmsprop":
+                st = gh.getTemplate("opt_default");
+                st.add("opt_name", "RMSProp");
+                break;
+            default:
+                if (!caffeOptimizer.equals("sgd")) {
+                    System.err.println("Unknown optimizer. Will use SGD instead.");
+                }
+
+                st = gh.getTemplate("opt_sgd");
+                st.add("opt_name", "SGD");
+                st.add("momentum", momentum);
+                break;
+        }
+        st.add("lr", lr);
+        st.add("wd", wd);
+
+        return st.render();
+    }
+
+    private String generateInitializer() {
+        ST st = gh.getTemplate("init_params");
+        st.add("params_file", paramsFilePath);
+        return st.render();
+    }
+
+    private String generateImports() {
+        return gh.getTemplate("imports").render();
+    }
+
+    private StringBuilder generateIterators() {
+        StringBuilder code = new StringBuilder();
+
+        for (Layer layer : mlModel.getDataLayers()) {
+            String iterator = generateIterator(layer);
+            code.append(iterator);
+        }
+
+        return code;
+    }
+
+    private String getIteratorName(String phase) {
+        for (Layer layer : mlModel.getDataLayers()) {
+            String layerPhase = layer.getAttr("include.phase", phase);
+            if (phase.equalsIgnoreCase(layerPhase)) {
+                return layerPhase.toLowerCase() + "_" + layer.getName() + "_" + "itr";
+            }
+        }
+        return null;
+    }
+
+    private List<String> getDataNames() {
+        return getDataNames(0);
+    }
+
+    private List<String> getLabelNames() {
+        return getDataNames(1);
+    }
+
+    private List<String> getDataNames(int topIndex) {
+        List<String> dataList = new ArrayList<String>();
+        for (Layer layer : mlModel.getDataLayers()) {
+            if ("train".equalsIgnoreCase(layer.getAttr("include.phase"))) {
+                List<String> tops = layer.getTops();
+                if (topIndex < tops.size()) {
+                    dataList.add(String.format("'%s'", tops.get(topIndex)));
+                }
+            }
+        }
+        return dataList;
+    }
+
+    private StringBuilder generateInputVars() {
+        StringBuilder code = new StringBuilder();
+
+        Set<String> tops = new HashSet<String>();
+
+        for (Layer layer : mlModel.getDataLayers())
+            for (String top : layer.getTops())
+                tops.add(top);
+
+        for (String top : tops)
+            code.append(gh.generateVar(gh.getVarname(top), top, null, null, null, null));
+
+        code.append(System.lineSeparator());
+        return code;
+    }
+
+    private String generateIterator(Layer layer) {
+        String iteratorName = layer.getAttr("include.phase");
+        iteratorName = iteratorName.toLowerCase();
+        iteratorName = iteratorName + "_" + layer.getName() + "_" + "itr";
+
+        ST st = stGroup.getInstanceOf("iterator");
+
+        String prototxt = layer.getPrototxt();
+        prototxt = prototxt.replace("\r", "");
+        prototxt = prototxt.replace("\n", " \\\n");
+        prototxt = "'" + prototxt + "'";
+        prototxt = Utils.indent(prototxt, 1, true, 4);
+
+        st.add("iter_name", iteratorName);
+        st.add("prototxt", prototxt);
+
+        String dataName = "???";
+        if (layer.getTops().size() >= 1) {
+            dataName = layer.getTops().get(0);
+        } else {
+            System.err.println(String.format("Data layer %s doesn't have data", layer.getName()));
+        }
+        st.add("data_name", dataName);
+
+        String labelName = "???";
+        if (layer.getTops().size() >= 2) {
+            labelName = layer.getTops().get(1);
+        } else {
+            System.err.println(String.format("Data layer %s doesn't have label", layer.getName()));
+        }
+        st.add("label_name", labelName);
+
+        if (layer.hasAttr("data_param.num_examples")) {
+            st.add("num_examples", layer.getAttr("data_param.num_examples"));
+        }
+
+        return st.render();
+    }
+
+}
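`generateLRUpdate()` selects an `lrpolicy_*` template by name, but the templates themselves are not part of this excerpt. For reference, the sketch below restates the learning-rate schedules that Caffe documents for each `lr_policy` value the switch handles. `caffe_learning_rate` is a hypothetical helper written in plain Python; it is not the generated MXNet code.

```python
import math
from typing import Sequence


def caffe_learning_rate(policy: str, it: int, base_lr: float, gamma: float = 0.1,
                        power: float = 1.0, stepsize: int = 1, max_iter: int = 1,
                        stepvalues: Sequence[int] = ()) -> float:
    """Learning rate at iteration `it` under Caffe's lr_policy rules (sketch)."""
    policy = policy.lower()  # Converter also lower-cases the policy name
    if policy == 'fixed':
        return base_lr
    if policy == 'step':
        return base_lr * gamma ** (it // stepsize)
    if policy == 'multistep':
        # Count the step boundaries (stepvalue entries) already reached
        passed = sum(1 for s in stepvalues if it >= s)
        return base_lr * gamma ** passed
    if policy == 'exp':
        return base_lr * gamma ** it
    if policy == 'inv':
        return base_lr * (1 + gamma * it) ** (-power)
    if policy == 'poly':
        return base_lr * (1 - it / max_iter) ** power
    if policy == 'sigmoid':
        return base_lr / (1 + math.exp(-gamma * (it - stepsize)))
    raise ValueError('Unknown lr_policy: ' + policy)


for it in (0, 999, 1000, 2500):
    print(it, caffe_learning_rate('step', it, base_lr=0.01, gamma=0.1, stepsize=1000))
```

`step` and `sigmoid` read `stepsize`, `inv` and `poly` read `power`, and `poly` also needs `max_iter`, which is presumably why `generateRunner()` hands `max_iter`, `stepsize`, `gamma` and `power` to the runner template.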
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/CreateModelListener.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/CreateModelListener.java
new file mode 100644
index 0000000..75800a1
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/CreateModelListener.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file CreateModelListener.java
+ * \brief ANTLR listener that builds MLModel as the parser parses the Caffe prototxt
+ */
+
+package io.mxnet.caffetranslator;
+
+import lombok.Getter;
+import org.antlr.v4.runtime.Token;
+import org.antlr.v4.runtime.TokenStream;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Stack;
+
+public class CreateModelListener extends CaffePrototxtBaseListener {
+
+    private final CaffePrototxtParser parser;
+    @Getter
+    private final MLModel mlModel;
+    private final Stack<String> keys;
+    private final ParserHelper parserHelper;
+
+    private Layer currentLayer;
+    private Map<String, String> currentParams;
+
+    public CreateModelListener(CaffePrototxtParser parser, MLModel mlModel) {
+        this.parser = parser;
+        this.mlModel = mlModel;
+        this.keys = new Stack<>();
+        this.currentParams = new HashMap<>();
+        this.parserHelper = new ParserHelper();
+    }
+
+    @Override
+    public void exitName(CaffePrototxtParser.NameContext ctx) {
+        String name = ctx.STRING().toString();
+        mlModel.setName(parserHelper.removeQuotes(name));
+    }
+
+    @Override
+    public void enterLayer(CaffePrototxtParser.LayerContext ctx) {
+        keys.clear();
+        currentLayer = new Layer();
+    }
+
+    @Override
+    public void exitLayer(CaffePrototxtParser.LayerContext ctx) {
+        TokenStream tokens = parser.getTokenStream();
+        String prototxt = getPrototxt(tokens, ctx.getStart().getTokenIndex(), ctx.getStop().getTokenIndex());
+
+        if (currentLayer.getTops().size() == 1) {
+            currentLayer.addAttr("top", currentLayer.getTops().get(0));
+        }
+
+        if (currentLayer.getBottoms().size() == 1) {
+            currentLayer.addAttr("bottom", currentLayer.getBottoms().get(0));
+        }
+
+        currentLayer.setPrototxt(prototxt);
+        mlModel.addLayer(currentLayer);
+    }
+
+    private String getPrototxt(TokenStream stream, int start, int end) {
+        StringBuilder prototxt = new StringBuilder();
+        for (int i = start; i <= end; i++) {
+            Token token = stream.get(i);
+            prototxt.append(token.getText());
+        }
+        String strPrototxt = prototxt.toString();
+        return strPrototxt.replaceAll(" +num_examples:.*\\s", "");
+    }
+
+    @Override
+    public void enterPair(CaffePrototxtParser.PairContext ctx) {
+        String key = ctx.getStart().getText();
+        keys.push(key);
+    }
+
+    @Override
+    public void exitPair(CaffePrototxtParser.PairContext ctx) {
+
+        if (getCurrentKey().equals("param")) {
+            currentLayer.getParams().add(currentParams);
+            currentParams = new HashMap<>();
+        }
+
+        keys.pop();
+    }
+
+    @Override
+    public void exitValueLeaf(CaffePrototxtParser.ValueLeafContext ctx) {
+        String value = ctx.getText();
+        value = parserHelper.removeQuotes(value);
+        processKeyValue(getCurrentKey(), value);
+    }
+
+    protected void processKeyValue(String key, String value) {
+        switch (key) {
+            case "name":
+                currentLayer.setName(value);
+                break;
+            case "top":
+                currentLayer.addTop(value);
+                return;
+            case "bottom":
+                currentLayer.addBottom(value);
+                return;
+        }
+
+        if (key.toLowerCase().startsWith("param.")) {
+            currentParams.put(key, value);
+        }
+
+        currentLayer.addAttr(key, value);
+    }
+
+    private String getCurrentKey() {
+        StringBuilder sb = new StringBuilder();
+        for (String s : keys) {
+            sb.append(s + ".");
+        }
+        return sb.substring(0, sb.length() - 1);
+    }
+}
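The listener never builds a tree: each `pair` pushes its key on entry, pops it on exit, and a leaf value is stored under the dot-joined key stack (`getCurrentKey()`). That is where attribute names such as `include.phase` and `data_param.num_examples`, read in `Converter`, come from. The hypothetical `flatten` function below reproduces that scheme over a nested list of `(key, value)` tuples, used as a stand-in for the parse tree; a list is used instead of a dict because prototxt keys such as `param` may repeat.

```python
from typing import List, Optional, Tuple


def flatten(pairs: list, keys: Optional[List[str]] = None) -> List[Tuple[str, str]]:
    """Record every leaf value under the dot-joined stack of enclosing keys.

    `pairs` is a list of (key, value) tuples; a value is either a string leaf
    or another such list (a nested `{ ... }` object).
    """
    if keys is None:
        keys = []
    out = []
    for key, value in pairs:
        keys.append(key)                         # like enterPair: push the key
        if isinstance(value, list):
            out.extend(flatten(value, keys))     # nested object: recurse
        else:
            out.append(('.'.join(keys), value))  # like exitValueLeaf
        keys.pop()                               # like exitPair: pop the key
    return out


conv_layer = [
    ('name', 'conv1'),
    ('type', 'Convolution'),
    ('param', [('lr_mult', '1')]),
    ('param', [('lr_mult', '2')]),
    ('convolution_param', [('num_output', '20'),
                           ('weight_filler', [('type', 'xavier')])]),
    ('include', [('phase', 'TRAIN')]),
]

for dotted_key, value in flatten(conv_layer):
    print(dotted_key, '=', value)
```

Both `param.lr_mult` entries come out under the same dotted key, which is presumably why the Java listener also collects each `param { }` block into its own map and appends it to `getParams()` in `exitPair`.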
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/GenerationHelper.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/GenerationHelper.java
new file mode 100644
index 0000000..1cac546
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/GenerationHelper.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file GenerationHelper.java
+ * \brief Helper class used by generators
+ */
+
+package io.mxnet.caffetranslator;
+
+import org.stringtemplate.v4.ST;
+import org.stringtemplate.v4.STErrorListener;
+import org.stringtemplate.v4.STGroup;
+import org.stringtemplate.v4.STGroupFile;
+import org.stringtemplate.v4.STRawGroupDir;
+import org.stringtemplate.v4.misc.STMessage;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class GenerationHelper {
+
+    protected final STGroup stGroupDir;
+
+    protected final STGroup stGroupFile;
+
+    private class SuppressSTErrorsListener implements STErrorListener {
+
+        @Override
+        public void compileTimeError(STMessage msg) {
+            // Do nothing
+        }
+
+        @Override
+        public void runTimeError(STMessage msg) {
+            // Do nothing
+        }
+
+        @Override
+        public void IOError(STMessage msg) {
+            throw new RuntimeException(msg.toString());
+        }
+
+        @Override
+        public void internalError(STMessage msg) {
+            throw new RuntimeException(msg.toString());
+        }
+    }
+
+    public GenerationHelper() {
+        this.stGroupDir = new STRawGroupDir("templates");
+        this.stGroupFile = new STGroupFile("templates/symbols.stg");
+
+        SuppressSTErrorsListener errListener = new SuppressSTErrorsListener();
+        stGroupDir.setListener(errListener);
+        stGroupFile.setListener(errListener);
+    }
+
+    public ST getTemplate(String name) {
+        ST st = stGroupDir.getInstanceOf(name);
+        if (st != null) {
+            return st;
+        }
+        return stGroupFile.getInstanceOf(name);
+    }
+
+    public String generateVar(String varName, String symName, String lr_mult, String wd_mult, String init, List<Integer> shape) {
+        ST st = getTemplate("var");
+        st.add("var", varName);
+        st.add("name", symName);
+
+        st.add("lr_mult", lr_mult);
+        st.add("wd_mult", wd_mult);
+        st.add("init", init);
+        st.add("shape", shape);
+
+        return st.render();
+    }
+
+    public String getInit(String fillerType, String fillerValue) {
+        if (fillerType == null && fillerValue == null) {
+            return null;
+        }
+
+        if (fillerType == null) {
+            fillerType = "constant";
+        }
+
+        if (fillerValue == null) {
+            fillerValue = "0";
+        }
+
+        String initializer;
+        switch (fillerType) {
+            case "xavier":
+                initializer = "mx.initializer.Xavier()";
+                break;
+            case "gaussian":
+                initializer = "mx.initializer.Normal()";
+                break;
+            case "constant":
+                initializer = String.format("mx.initializer.Constant(%s)", fillerValue);
+                break;
+            case "bilinear":
+                initializer = "mx.initializer.Bilinear()";
+                break;
+            default:
+                initializer = "UnknownInitializer";
+                System.err.println("Initializer " + fillerType + " not supported");
+                break;
+        }
+
+        return initializer;
+    }
+
+    public String getVarname(String name) {
+        StringBuilder sb = new StringBuilder(name);
+        for (int i = 0; i < sb.length(); i++) {
+            char ch = sb.charAt(i);
+            // Keep letters, digits and '_'; replace anything else (e.g. '/', '-', '.')
+            // so that every character is legal in a Python identifier.
+            if (!(Character.isLetterOrDigit(ch) || ch == '_')) {
+                sb.setCharAt(i, '_');
+            }
+        }
+        return sb.toString();
+    }
+
+    public List<String> getVarNames(List<String> names) {
+        List<String> list = new ArrayList<>();
+        for (String name : names) {
+            list.add(getVarname(name));
+        }
+        return list;
+    }
+
+    public void fillNameDataAndVar(ST st, Layer layer) {
+        st.add("name", layer.getName());
+        st.add("data", getVarname(layer.getBottom()));
+        st.add("var", getVarname(layer.getTop()));
+    }
+
+    public void simpleFillTemplate(ST st, String name, Layer layer, String key, String defaultValue, String... altKeys) {
+        String value = layer.getAttr(key);
+
+        if (value == null) {
+            for (String altKey : altKeys) {
+                value = layer.getAttr(altKey);
+                if (value != null) {
+                    break;
+                }
+            }
+        }
+
+        if (value == null && defaultValue != null) {
+            value = defaultValue;
+        }
+
+        if (value == null) {
+            System.err.println(String.format("Layer %s does not contain attribute %s or alternates",
+                    layer.getName(), key));
+            value = "???";
+        }
+
+        st.add(name, value);
+    }
+
+    public GeneratorOutput makeGeneratorOutput(String code, int numLayersTranslated) {
+        return new GeneratorOutput(code, numLayersTranslated);
+    }
+
+    public String initializeParam(String varname, int childIndex, String initializer) {
+        StringBuilder out = new StringBuilder();
+        out.append(String.format("param_initializer.add_param(%s.get_children()[%d].name, %s)",
+                varname, childIndex, initializer));
+        out.append(System.lineSeparator());
+        return out.toString();
+    }
+}
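GenerationHelper turns Caffe layer and blob names into Python variable names and maps Caffe filler types onto MXNet initializer expressions. The JDK-only sketch below mirrors that logic outside StringTemplate so it can be run on its own; `NamingSketch` is a hypothetical class, and it covers only the four filler types handled in `getInit`.

```java
// Hypothetical stand-alone mirror of GenerationHelper.getVarname and getInit.
class NamingSketch {

    // Replace every character other than a letter, digit or '_' with '_',
    // e.g. "conv1/3x3_reduce" -> "conv1_3x3_reduce".
    static String toVarName(String name) {
        StringBuilder sb = new StringBuilder(name);
        for (int i = 0; i < sb.length(); i++) {
            char ch = sb.charAt(i);
            if (!(Character.isLetterOrDigit(ch) || ch == '_')) {
                sb.setCharAt(i, '_');
            }
        }
        return sb.toString();
    }

    // Map a Caffe filler (type, value) onto the Python initializer expression
    // emitted into the generated script; returns null when neither is set.
    static String toInitializer(String fillerType, String fillerValue) {
        if (fillerType == null && fillerValue == null) {
            return null;
        }
        String type = (fillerType == null) ? "constant" : fillerType;
        String value = (fillerValue == null) ? "0" : fillerValue;
        switch (type) {
            case "xavier":
                return "mx.initializer.Xavier()";
            case "gaussian":
                return "mx.initializer.Normal()";
            case "constant":
                return "mx.initializer.Constant(" + value + ")";
            case "bilinear":
                return "mx.initializer.Bilinear()";
            default:
                return "UnknownInitializer";
        }
    }
}
```

Note that the gaussian case emits `mx.initializer.Normal()` with no sigma argument, so MXNet's default sigma applies unless the filler's `std` is handled elsewhere in the translator.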
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/GeneratorOutput.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/GeneratorOutput.java
new file mode 100644
index 0000000..fd27bb3
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/GeneratorOutput.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file GeneratorOutput.java
+ * \brief Output of each generator
+ */
+
+package io.mxnet.caffetranslator;
+
+public class GeneratorOutput {
+    public final String code;
+    public final int numLayersTranslated;
+
+    public GeneratorOutput(String code, int n) {
+        this.code = code;
+        this.numLayersTranslated = n;
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Launcher.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Launcher.java
new file mode 100644
index 0000000..9fd3cbd
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Launcher.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Launcher.java
+ * \brief Parses command line and invokes Converter
+ */
+
+package io.mxnet.caffetranslator;
+
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.CommandLineParser;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.PrintWriter;
+
+public class Launcher {
+
+    private String trainingPrototextPath, solverPrototextPath;
+    private String paramsFilePath;
+    private File outFile;
+
+    protected static final String TRAINING_PROTOTXT = "training-prototxt";
+    protected static final String SOLVER_PROTOTXT = "solver";
+    protected static final String CUSTOM_DATA_LAYERS = "custom-data-layers";
+    protected static final String OUTPUT_FILE = "output-file";
+    protected static final String PARAMS_FILE = "params-file";
+    protected static final String GRAPH_FILE = "graph-file";
+
+
+    public static void main(String[] args) {
+        Launcher launcher = new Launcher();
+        launcher.run(args);
+    }
+
+    public void run(String[] args) {
+        parseCommandLine(args);
+
+        Converter converter = new Converter(trainingPrototextPath, solverPrototextPath);
+        if (paramsFilePath != null) {
+            converter.setParamsFilePath(paramsFilePath);
+        }
+        String code = converter.generateMXNetCode();
+
+        if (writeToOutFile(code)) {
+            System.out.println("Translated code saved in " + outFile.getAbsolutePath());
+        }
+    }
+
+    private boolean writeToOutFile(String code) {
+        // try-with-resources flushes and closes the writer when the block exits
+        try (PrintWriter out = new PrintWriter(outFile)) {
+            out.print(code);
+        } catch (FileNotFoundException e) {
+            System.err.println(String.format("Unable to open %s for writing", outFile.getAbsolutePath()));
+            return false;
+        }
+
+        return true;
+    }
+
+    public void parseCommandLine(String[] args) {
+        CommandLineParser clParser = new DefaultParser();
+
+        Options options = new Options();
+
+        Option prototxtOption = Option.builder("t")
+                .longOpt(TRAINING_PROTOTXT)
+                .hasArg()
+                .desc("training/validation prototxt")
+                .build();
+        options.addOption(prototxtOption);
+
+        Option solverOption = Option.builder("s")
+                .longOpt(SOLVER_PROTOTXT)
+                .hasArg()
+                .desc("solver prototxt")
+                .build();
+        options.addOption(solverOption);
+
+        Option dataLayerOpt = Option.builder("c")
+                .longOpt(CUSTOM_DATA_LAYERS)
+                .hasArg()
+                .desc("Comma separated custom data layers")
+                .build();
+        options.addOption(dataLayerOpt);
+
+        Option outfileOpt = Option.builder("o")
+                .longOpt(OUTPUT_FILE)
+                .hasArg()
+                .desc("Output file")
+                .build();
+        options.addOption(outfileOpt);
+
+        Option paramsFileOpt = Option.builder("p")
+                .longOpt(PARAMS_FILE)
+                .hasArg()
+                .desc("Params file")
+                .build();
+        options.addOption(paramsFileOpt);
+
+        Option graphFileOpt = Option.builder("g")
+                .longOpt(GRAPH_FILE)
+                .hasArg()
+                .desc("Image file to visualize computation graph")
+                .build();
+        options.addOption(graphFileOpt);
+
+        CommandLine line = null;
+        try {
+            line = clParser.parse(options, args);
+        } catch (ParseException e) {
+            System.err.println("Error parsing command line: " + e.getMessage());
+            System.exit(1);
+        }
+
+        if ((trainingPrototextPath = getOption(line, TRAINING_PROTOTXT)) == null) {
+            bail("Command line argument " + TRAINING_PROTOTXT + " missing");
+        }
+
+        if ((solverPrototextPath = getOption(line, SOLVER_PROTOTXT)) == null) {
+            bail("Command line argument " + SOLVER_PROTOTXT + " missing");
+        }
+
+        String strOutFile = getOption(line, OUTPUT_FILE);
+        if (strOutFile == null) {
+            bail("Command line argument " + OUTPUT_FILE + " missing");
+        }
+        outFile = new File(strOutFile);
+
+        paramsFilePath = getOption(line, PARAMS_FILE);
+
+        String dataLayers;
+        Config config = Config.getInstance();
+        if ((dataLayers = getOption(line, CUSTOM_DATA_LAYERS)) != null) {
+            for (String name : dataLayers.split(",")) {
+                name = name.trim();
+                config.addCustomDataLayer(name);
+            }
+        }
+
+    }
+
+    private String getOption(CommandLine line, String argName) {
+        if (line.hasOption(argName)) {
+            return line.getOptionValue(argName);
+        } else {
+            return null;
+        }
+    }
+
+    private void bail(String reason) {
+        System.err.println(reason);
+        System.exit(1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Layer.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Layer.java
new file mode 100644
index 0000000..dac8ff4
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Layer.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Layer.java
+ * \brief Model for a layer
+ */
+
+package io.mxnet.caffetranslator;
+
+import lombok.Getter;
+import lombok.Setter;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Layer {
+
+    @Getter
+    @Setter
+    private String name;
+
+    @Getter
+    @Setter
+    private int layerIndex;
+
+    @Getter
+    @Setter
+    private Kind kind;
+
+    @Getter
+    @Setter
+    private String prototxt;
+
+    @Getter
+    private final List<String> bottoms;
+
+    @Getter
+    private final List<String> tops;
+
+    @Setter
+    @Getter
+    private List<Map<String, String>> params;
+
+    @Setter
+    private Map<String, List<String>> attr;
+
+    public Layer() {
+        tops = new ArrayList<>();
+        bottoms = new ArrayList<>();
+        attr = new HashMap<>();
+        params = new ArrayList<>();
+    }
+
+    public Layer(int layerIndex) {
+        this();
+        this.layerIndex = layerIndex;
+    }
+
+    public void addAttr(String key, String value) {
+        List<String> list = attr.get(key);
+        if (list == null) {
+            list = new ArrayList<String>();
+            list.add(value);
+            attr.put(key, list);
+        } else {
+            list.add(value);
+        }
+    }
+
+    public String getAttr(String key) {
+        List<String> list = attr.get(key);
+        if (list == null) {
+            return null;
+        }
+
+        return list.get(0);
+    }
+
+    public String getAttr(String key, String defaultValue) {
+        String attr = getAttr(key);
+        return attr != null ? attr : defaultValue;
+    }
+
+    public boolean hasAttr(String key) {
+        return attr.containsKey(key);
+    }
+
+    public boolean attrEquals(String key, String value) {
+        if (!attr.containsKey(key)) {
+            return false;
+        }
+        return getAttr(key).equals(value);
+    }
+
+    public List<String> getAttrList(String key) {
+        return attr.get(key);
+    }
+
+    public void addTop(String top) {
+        tops.add(top);
+    }
+
+    public void addBottom(String bottom) {
+        bottoms.add(bottom);
+    }
+
+    public String getBottom() {
+        return bottoms.size() > 0 ? bottoms.get(0) : null;
+    }
+
+    public String getType() {
+        return getAttr("type");
+    }
+
+    public String getTop() {
+        return tops.size() > 0 ? tops.get(0) : null;
+    }
+
+    public enum Kind {
+        DATA, INTERMEDIATE, LOSS;
+    }
+}
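Layer stores every attribute as a list because a key can occur more than once in a layer definition (for example `param.lr_mult`, once per `param { ... }` block); `getAttr` returns the first occurrence and `getAttrList` returns all of them. A minimal JDK-only sketch of that storage, using a hypothetical `AttrBag` class and `computeIfAbsent` in place of the explicit null check in `addAttr`:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of Layer's multi-valued attribute map.
class AttrBag {
    private final Map<String, List<String>> attr = new HashMap<>();

    // Append a value, creating the list on first use of the key.
    void add(String key, String value) {
        attr.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
    }

    // First value recorded for the key, or null when the key is absent.
    String get(String key) {
        List<String> values = attr.get(key);
        return (values == null) ? null : values.get(0);
    }

    // First value, or the supplied default when the key is absent.
    String get(String key, String defaultValue) {
        String value = get(key);
        return (value != null) ? value : defaultValue;
    }

    // Every value recorded for the key, in prototxt order (null when absent).
    List<String> getAll(String key) {
        return attr.get(key);
    }
}
```

With this layout, weight and bias learning-rate multipliers from two `param` blocks stay in declaration order, so index 0 and index 1 can be matched to the weight and the bias respectively.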
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/MLModel.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/MLModel.java
new file mode 100644
index 0000000..08f0fe7
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/MLModel.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file MLModel.java
+ * \brief Models an ML model: the ordered layer list plus a lookup by layer name and phase
+ */
+
+package io.mxnet.caffetranslator;
+
+import lombok.Getter;
+import lombok.Setter;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class MLModel {
+
+    public MLModel() {
+        layerList = new ArrayList<>();
+        layerLookup = new HashMap<>();
+        layerIndex = 0;
+    }
+
+    @Getter
+    @Setter
+    private String name;
+
+    @Getter
+    @Setter
+    private List<Layer> layerList;
+
+    private final Map<String, Map<String, Layer>> layerLookup;
+
+    private int layerIndex;
+
+    public void addLayer(Layer layer) {
+
+        layer.setLayerIndex(layerIndex++);
+        layerList.add(layer);
+
+        String name = layer.getName();
+        String includePhase = layer.getAttr("include.phase");
+        includePhase = (includePhase == null) ? "" : includePhase;
+
+        if (layerLookup.containsKey(name)) {
+            layerLookup.get(name).put(includePhase, layer);
+        } else {
+            Map<String, Layer> map = new HashMap<>();
+            map.put(includePhase, layer);
+            layerLookup.put(name, map);
+        }
+
+        String type = layer.getAttr("type");
+        Config config = Config.getInstance();
+        if (type.equals("Data") || config.getCustomDataLayers().contains(type)) {
+            layer.setKind(Layer.Kind.DATA);
+        } else if (type.toLowerCase().endsWith("loss")) {
+            layer.setKind(Layer.Kind.LOSS);
+        } else {
+            layer.setKind(Layer.Kind.INTERMEDIATE);
+        }
+    }
+
+    public List<Layer> getDataLayers() {
+        List<Layer> ret = new ArrayList<>();
+
+        for (Layer layer : layerList) {
+            if (layer.getKind() == Layer.Kind.DATA) {
+                ret.add(layer);
+            }
+        }
+        return ret;
+    }
+
+    public List<Layer> getNonDataLayers() {
+        List<Layer> ret = new ArrayList<>();
+
+        for (Layer layer : layerList) {
+            if (layer.getKind() != Layer.Kind.DATA) {
+                ret.add(layer);
+            }
+        }
+        return ret;
+    }
+
+}
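MLModel classifies each layer by its Caffe `type`: `Data`, or any type passed through `--custom-data-layers`, becomes DATA; any type ending in "loss" (case-insensitive) becomes LOSS; everything else is INTERMEDIATE. The sketch below reproduces that rule together with the comma-separated `-c` parsing done in Launcher, using only the JDK; `LayerKindSketch` is a hypothetical class, not part of the translator.

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical sketch of the rule MLModel.addLayer uses to assign Layer.Kind.
class LayerKindSketch {
    enum Kind { DATA, INTERMEDIATE, LOSS }

    private final Set<String> customDataLayers = new HashSet<>();

    // Mirrors Launcher: "-c" takes a comma-separated list, each entry trimmed.
    LayerKindSketch(String customDataLayersArg) {
        if (customDataLayersArg != null) {
            for (String name : customDataLayersArg.split(",")) {
                customDataLayers.add(name.trim());
            }
        }
    }

    Kind classify(String type) {
        if (type.equals("Data") || customDataLayers.contains(type)) {
            return Kind.DATA;
        } else if (type.toLowerCase().endsWith("loss")) {
            return Kind.LOSS;
        }
        return Kind.INTERMEDIATE;
    }
}
```

Under this rule an `Accuracy` layer is classified as INTERMEDIATE, and a data layer type other than `Data` is only recognized if it is listed with `-c`.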
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/ParserHelper.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/ParserHelper.java
new file mode 100644
index 0000000..bf8a581
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/ParserHelper.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ParserHelper.java
+ * \brief Helpers shared by the prototxt parse listeners
+ */
+
+package io.mxnet.caffetranslator;
+
+public class ParserHelper {
+    public String removeQuotes(String arg) {
+        boolean doubleQuoteStr = (arg.startsWith("\"") && arg.endsWith("\""));
+        boolean singleQuoteStr = (arg.startsWith("'") && arg.endsWith("'"));
+        if ((singleQuoteStr || doubleQuoteStr) && arg.length() >= 2) {
+            arg = arg.substring(1, arg.length() - 1);
+        }
+        return arg;
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java
new file mode 100644
index 0000000..ec4c812
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Solver.java
+ * \brief Model for the Caffe solver prototxt
+ */
+
+package io.mxnet.caffetranslator;
+
+import org.antlr.v4.runtime.CharStream;
+import org.antlr.v4.runtime.CharStreams;
+import org.antlr.v4.runtime.CommonTokenStream;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class Solver {
+
+    private boolean parseDone;
+    private Map<String, List<String>> properties;
+    private final String solverPath;
+
+    public Solver(String solverPath) {
+        this.solverPath = solverPath;
+        properties = new HashMap<>();
+    }
+
+    public boolean parsePrototxt() {
+        CharStream cs;
+        // try-with-resources closes the file; CharStreams.fromStream has read it fully by then
+        try (FileInputStream fis = new FileInputStream(new File(solverPath))) {
+            cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
+        } catch (IOException e) {
+            System.err.println("Unable to read prototxt " + solverPath);
+            return false;
+        }
+
+        CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
+        CommonTokenStream tokens = new CommonTokenStream(lexer);
+        CaffePrototxtParser parser = new CaffePrototxtParser(tokens);
+
+        SolverListener solverListener = new SolverListener();
+        parser.addParseListener(solverListener);
+        parser.solver();
+
+        properties = solverListener.getProperties();
+
+        parseDone = true;
+        return true;
+    }
+
+    public String getProperty(String key) {
+        List<String> list = getProperties(key);
+        if (list == null) {
+            return null;
+        }
+        return list.get(0);
+    }
+
+    public List<String> getProperties(String key) {
+        if (!parseDone) {
+            parsePrototxt();
+        }
+
+        return properties.get(key);
+    }
+
+    public String getProperty(String key, String defaultValue) {
+        String value = getProperty(key);
+        if (value == null) {
+            return defaultValue;
+        } else {
+            return value;
+        }
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SolverListener.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SolverListener.java
new file mode 100644
index 0000000..18b7fe1
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SolverListener.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file SolverListener.java
+ * \brief ANTLR listener that builds the Solver instance as the solver prototxt is parsed
+ */
+
+package io.mxnet.caffetranslator;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class SolverListener extends CaffePrototxtBaseListener {
+
+    private final Map<String, List<String>> properties;
+    private final ParserHelper parserHelper;
+
+    public SolverListener() {
+        properties = new HashMap<>();
+        parserHelper = new ParserHelper();
+    }
+
+    public Map<String, List<String>> getProperties() {
+        return properties;
+    }
+
+    @Override
+    public void exitPair(CaffePrototxtParser.PairContext ctx) {
+        String key = ctx.ID().getText();
+        String value = ctx.value().getText();
+        value = parserHelper.removeQuotes(value);
+
+        if (properties.get(key) == null) {
+            properties.put(key, new ArrayList<>());
+        }
+
+        properties.get(key).add(value);
+    }
+}
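Solver parses the solver prototxt lazily on the first property lookup and, through SolverListener, keeps every value of a repeated key (for example several `stepvalue` entries for a multistep learning-rate policy), while `getProperty(key, default)` supplies defaults. The real parsing goes through the generated ANTLR grammar; purely for illustration, the sketch below substitutes a naive line-based `key: value` reader. `SolverSketch` is hypothetical and does not handle nested blocks, several pairs on one line, or `#` inside quoted values.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified stand-in for Solver + SolverListener.
// A naive line reader replaces the ANTLR grammar used by the real code.
class SolverSketch {
    private final Map<String, List<String>> properties = new HashMap<>();

    SolverSketch(String prototxt) {
        for (String line : prototxt.split("\n")) {
            int hash = line.indexOf('#');
            if (hash >= 0) {
                line = line.substring(0, hash); // drop trailing comment
            }
            int colon = line.indexOf(':');
            if (colon < 0) {
                continue; // blank or unsupported line
            }
            String key = line.substring(0, colon).trim();
            String value = removeQuotes(line.substring(colon + 1).trim());
            // Repeated keys accumulate in order, as in SolverListener.exitPair.
            properties.computeIfAbsent(key, k -> new ArrayList<>()).add(value);
        }
    }

    // Similar to ParserHelper.removeQuotes: strip one pair of matching quotes.
    private static String removeQuotes(String s) {
        boolean quoted = s.length() >= 2
                && ((s.startsWith("\"") && s.endsWith("\"")) || (s.startsWith("'") && s.endsWith("'")));
        return quoted ? s.substring(1, s.length() - 1) : s;
    }

    List<String> getProperties(String key) {
        return properties.get(key);
    }

    String getProperty(String key, String defaultValue) {
        List<String> values = properties.get(key);
        return (values == null) ? defaultValue : values.get(0);
    }
}
```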
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SymbolGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SymbolGenerator.java
new file mode 100644
index 0000000..7a21aed
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SymbolGenerator.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file SymbolGenerator.java
+ * \brief Interface that every layer generator implements
+ */
+
+package io.mxnet.caffetranslator;
+
+public interface SymbolGenerator {
+    public GeneratorOutput generate(Layer layer, MLModel model);
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SymbolGeneratorFactory.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SymbolGeneratorFactory.java
new file mode 100644
index 0000000..5dea77e
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/SymbolGeneratorFactory.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file SymbolGeneratorFactory.java
+ * \brief A factory used to create a generator for a given layer type
+ */
+
+package io.mxnet.caffetranslator;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class SymbolGeneratorFactory {
+
+    private static SymbolGeneratorFactory instance = new SymbolGeneratorFactory();
+    Map<String, SymbolGenerator> generators;
+
+    public static SymbolGeneratorFactory getInstance() {
+        return instance;
+    }
+
+    private SymbolGeneratorFactory() {
+        if (instance != null) {
+            throw new IllegalStateException("SymbolGeneratorFactory already instantiated");
+        }
+        generators = new HashMap<>();
+    }
+
+    public SymbolGenerator getGenerator(String symbolType) {
+        return generators.get(symbolType);
+    }
+
+    public void addGenerator(String symbolType, SymbolGenerator generator) {
+        generators.put(symbolType, generator);
+    }
+}
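Each supported Caffe layer type is translated by a SymbolGenerator that SymbolGeneratorFactory looks up by type name; a null result means no generator is registered for that type. The JDK-only sketch below shows the same registry idea with a lambda-friendly interface. `GeneratorRegistrySketch`, `SimpleGenerator` and the ReLU format string are hypothetical; `mx.sym.Activation(data=..., act_type='relu', name=...)` is a real MXNet symbol call, but this is not the template the translator ships.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the type -> generator registry behind SymbolGeneratorFactory.
class GeneratorRegistrySketch {

    // Simplified generator: (layer name, input var, output var) -> Python code.
    interface SimpleGenerator {
        String generate(String name, String data, String var);
    }

    private final Map<String, SimpleGenerator> generators = new HashMap<>();

    void addGenerator(String caffeType, SimpleGenerator generator) {
        generators.put(caffeType, generator);
    }

    // Returns null when no generator is registered for the type,
    // mirroring SymbolGeneratorFactory.getGenerator.
    String translate(String caffeType, String name, String data, String var) {
        SimpleGenerator generator = generators.get(caffeType);
        return (generator == null) ? null : generator.generate(name, data, var);
    }
}
```

Keeping the registry keyed by the Caffe type string means adding support for a new layer only requires registering one more generator, without touching the code that walks the layer list.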
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Utils.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Utils.java
new file mode 100644
index 0000000..0b006b1
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Utils.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Utils.java
+ * \brief General util functions
+ */
+
+package io.mxnet.caffetranslator;
+
+import java.util.Collections;
+
+public class Utils {
+    public static String indent(String str, int level, boolean useSpaces, int numSpaces) {
+        String prefix;
+        if (!useSpaces) {
+            prefix = String.join("", Collections.nCopies(level, "\t"));
+        } else {
+            String spaces = String.join("", Collections.nCopies(numSpaces, " "));
+            prefix = String.join("", Collections.nCopies(level, spaces));
+        }
+
+        String indented = str.replaceAll("(?m)^", prefix);
+        return indented;
+    }
+}
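Utils.indent adds the prefix to every line, not just the first, because of the (?m) flag. In MULTILINE mode, ^ matches at index 0 and after each line terminator, except a terminator that ends the input. The following is a minimal standalone restatement of that logic. IndentDemo is a hypothetical class and is not part of the commit.

```java
import java.util.Collections;

// Hypothetical standalone copy of the logic in Utils.indent (not part of the commit).
class IndentDemo {
    // Builds a prefix of `level` tabs, or `level` groups of `numSpaces` spaces,
    // and inserts it at the start of every line of `str`.
    static String indent(String str, int level, boolean useSpaces, int numSpaces) {
        String unit = useSpaces ? String.join("", Collections.nCopies(numSpaces, " ")) : "\t";
        String prefix = String.join("", Collections.nCopies(level, unit));
        // (?m) enables MULTILINE mode, so ^ matches after every line terminator
        // (except one at the very end of the input), not only at index 0.
        return str.replaceAll("(?m)^", prefix);
    }
}
```

For example, IndentDemo.indent("x = 1\ny = 2", 1, true, 4) yields "    x = 1\n    y = 2". replaceAll treats $ and \ specially in the replacement string, so this approach is only safe for whitespace prefixes like these.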
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/AccuracyMetricsGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/AccuracyMetricsGenerator.java
new file mode 100644
index 0000000..d1f185f
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/AccuracyMetricsGenerator.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file AccuracyMetricsGenerator.java
+ * \brief Generate Accuracy metric
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GenerationHelper;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class AccuracyMetricsGenerator {
+
+    private final Map<String, String> map;
+    private final GenerationHelper gh;
+
+    public AccuracyMetricsGenerator() {
+        map = new HashMap<>();
+        gh = new GenerationHelper();
+    }
+
+    public String generate(MLModel model) {
+        StringBuilder out = new StringBuilder();
+        generateMap(model);
+
+        for (Layer layer : model.getLayerList()) {
+            if (layer.getType().equals("Accuracy")) {
+                ST st;
+                if (layer.getAttr("accuracy_param.top_k", "1").equals("1")) {
+                    st = gh.getTemplate("accuracy");
+                } else {
+                    st = gh.getTemplate("top_k_accuracy");
+                    st.add("k", layer.getAttr("accuracy_param.top_k"));
+                }
+
+                st.add("var", gh.getVarname(layer.getTop()));
+                String outputName = map.get(layer.getBottoms().get(0)) + "_output";
+                st.add("output_name", outputName);
+                st.add("label_name", layer.getBottoms().get(1));
+                st.add("name", layer.getName());
+
+                out.append(st.render());
+                out.append(System.lineSeparator());
+            }
+        }
+
+        return out.toString();
+    }
+
+    private void generateMap(MLModel model) {
+        for (Layer layer : model.getLayerList()) {
+            // If this is not SoftmaxWithLoss, move on
+            if (!layer.getType().equals("SoftmaxWithLoss")) {
+                continue;
+            }
+
+            map.put(layer.getBottoms().get(0), layer.getName());
+        }
+    }
+}
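AccuracyMetricsGenerator does not point the metric at the Caffe blob directly. It first maps the first bottom of each SoftmaxWithLoss layer to that loss layer's name. It then names the metric input "<loss layer name>_output", which appears to follow MXNet's habit of naming a symbol's output "<name>_output". When accuracy_param.top_k is anything other than 1, it switches to the top_k_accuracy template. The sketch below restates both decisions using a hypothetical AccLayer stand-in for the translator's Layer class.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class AccuracySketch {
    // Hypothetical stand-in for the translator's Layer: name, type and bottom blob names.
    static final class AccLayer {
        final String name;
        final String type;
        final List<String> bottoms;

        AccLayer(String name, String type, List<String> bottoms) {
            this.name = name;
            this.type = type;
            this.bottoms = bottoms;
        }
    }

    // Template chosen for accuracy_param.top_k; an unset value is treated as "1".
    static String templateFor(String topK) {
        return (topK == null || topK.equals("1")) ? "accuracy" : "top_k_accuracy";
    }

    // Same two steps as generateMap() and generate(): index SoftmaxWithLoss
    // layers by their first bottom, then derive "<loss layer name>_output".
    static String outputNameFor(List<AccLayer> layers, AccLayer accuracy) {
        Map<String, String> lossByBottom = new HashMap<>();
        for (AccLayer layer : layers) {
            if (layer.type.equals("SoftmaxWithLoss")) {
                lossByBottom.put(layer.bottoms.get(0), layer.name);
            }
        }
        String loss = lossByBottom.get(accuracy.bottoms.get(0));
        // With no matching loss layer the commit would emit "null_output";
        // the sketch returns null to make that case explicit.
        return (loss == null) ? null : loss + "_output";
    }
}
```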
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/BaseGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/BaseGenerator.java
new file mode 100644
index 0000000..0d7fc05
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/BaseGenerator.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file BaseGenerator.java
+ * \brief Base class for all source generators
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GenerationHelper;
+import io.mxnet.caffetranslator.SymbolGenerator;
+import org.stringtemplate.v4.ST;
+
+import java.util.List;
+
+public abstract class BaseGenerator implements SymbolGenerator {
+
+    protected final GenerationHelper gh;
+
+    public BaseGenerator() {
+        gh = new GenerationHelper();
+    }
+
+    protected ST getTemplate(String name) {
+        return gh.getTemplate(name);
+    }
+
+    protected String generateVar(String varName, String symName, String lr_mult, String wd_mult, String init, List<Integer> shape) {
+        ST st = getTemplate("var");
+        st.add("var", varName);
+        st.add("name", symName);
+
+        st.add("lr_mult", (lr_mult == null) ? "None" : lr_mult);
+        st.add("wd_mult", (wd_mult == null) ? "None" : wd_mult);
+        st.add("init", (init == null) ? "None" : init);
+        if (shape != null) {
+            st.add("shape", shape);
+        }
+
+        return st.render();
+    }
+
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/BatchNormGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/BatchNormGenerator.java
new file mode 100644
index 0000000..503bd3e
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/BatchNormGenerator.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file BatchNormGenerator.java
+ * \brief Generate BatchNorm layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class BatchNormGenerator extends BaseGenerator {
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("batchnorm");
+
+        gh.fillNameDataAndVar(st, layer);
+
+        if (layer.attrEquals("batch_norm_param.use_global_stats", "true")) {
+            st.add("use_global_stats", true);
+        }
+
+        int layerIndex = layer.getLayerIndex();
+        Layer nextLayer = model.getLayerList().get(layerIndex + 1);
+
+        boolean nextLayerIsScale = false;
+        if (nextLayer.getType().toLowerCase().equals("scale")) {
+            String axis = nextLayer.getAttr("ScaleParameter.axis", "1");
+            String numAxis = nextLayer.getAttr("ScaleParameter.num_axes", "1");
+            if (axis.equals("1") && numAxis.equals("1")) {
+                String biasTerm = nextLayer.getAttr("ScaleParameter.bias_term", "false");
+                if (biasTerm.toLowerCase().equals("false")) {
+                    nextLayerIsScale = true;
+                }
+            }
+        }
+
+        if (!nextLayerIsScale) {
+            st.add("fix_beta", true);
+            st.add("fix_gamma", true);
+        }
+
+        return new GeneratorOutput(st.render(), nextLayerIsScale ? 2 : 1);
+    }
+}
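BatchNormGenerator can fold the layer that directly follows it into the BatchNorm. This happens when that layer is a Scale with axis 1, num_axes 1 and bias_term false. In that case it returns GeneratorOutput(..., 2), so the Scale layer is consumed as well. Otherwise it sets fix_gamma and fix_beta and consumes one layer.

The sketch below restates that decision over a hypothetical SketchLayer type, not the translator's Layer. It also adds a bounds check. The commit calls get(layerIndex + 1) unconditionally, so it would likely throw if BatchNorm were the last layer in the list.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class BatchNormScaleSketch {
    // Hypothetical stand-in for the translator's Layer: a type plus
    // flattened prototxt attributes.
    static final class SketchLayer {
        final String type;
        final Map<String, String> attrs = new HashMap<>();

        SketchLayer(String type) {
            this.type = type;
        }

        String getAttr(String key, String defaultValue) {
            return attrs.getOrDefault(key, defaultValue);
        }
    }

    // Returns 2 when the layer after bnIndex is a Scale that the commit's
    // conditions fold into BatchNorm, otherwise 1.
    static int layersConsumed(List<SketchLayer> layers, int bnIndex) {
        if (bnIndex + 1 >= layers.size()) {
            return 1; // BatchNorm is the last layer: nothing to fold
        }
        SketchLayer next = layers.get(bnIndex + 1);
        if (!next.type.equalsIgnoreCase("scale")) {
            return 1;
        }
        boolean singleAxis = next.getAttr("ScaleParameter.axis", "1").equals("1")
                && next.getAttr("ScaleParameter.num_axes", "1").equals("1");
        boolean noBias = next.getAttr("ScaleParameter.bias_term", "false").equalsIgnoreCase("false");
        return (singleAxis && noBias) ? 2 : 1;
    }
}
```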
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ConcatGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ConcatGenerator.java
new file mode 100644
index 0000000..c9a5794
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ConcatGenerator.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ConcatGenerator.java
+ * \brief Generate Concat layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class ConcatGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("concat");
+
+        st.add("name", layer.getName());
+        st.add("var", gh.getVarname(layer.getTop()));
+        st.add("data", gh.getVarNames(layer.getBottoms()));
+
+        String dim = layer.getAttr("concat_param.axis");
+        if (dim != null) {
+            st.add("dim", dim);
+        }
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ConvolutionGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ConvolutionGenerator.java
new file mode 100644
index 0000000..eda59e5
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ConvolutionGenerator.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ConvolutionGenerator.java
+ * \brief Generate Convolution layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+import java.util.Map;
+
+public class ConvolutionGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        StringBuilder out = new StringBuilder();
+
+        ST st = getTemplate("convolution");
+        gh.fillNameDataAndVar(st, layer);
+
+        // Set kernel size
+        gh.simpleFillTemplate(st, "kernel_h", layer, "convolution_param.kernel_h", null,
+                "convolution_param.kernel_size");
+        gh.simpleFillTemplate(st, "kernel_w", layer, "convolution_param.kernel_w", null,
+                "convolution_param.kernel_size");
+
+        // Set stride
+        gh.simpleFillTemplate(st, "stride_h", layer, "convolution_param.stride_h", "1",
+                "convolution_param.stride");
+        gh.simpleFillTemplate(st, "stride_w", layer, "convolution_param.stride_w", "1",
+                "convolution_param.stride");
+
+        // Set padding
+        gh.simpleFillTemplate(st, "pad_h", layer, "convolution_param.pad_h", "0",
+                "convolution_param.pad");
+        gh.simpleFillTemplate(st, "pad_w", layer, "convolution_param.pad_w", "0",
+                "convolution_param.pad");
+
+        // Use bias?
+        if (layer.attrEquals("convolution_param.bias_term", "false")) {
+            st.add("no_bias", "NoBiasPlease"); //value doesn't matter
+        }
+
+        // Number of channels in output
+        gh.simpleFillTemplate(st, "num_filter", layer, "convolution_param.num_output", null);
+
+        String weightInit = gh.getInit(
+                layer.getAttr("convolution_param.weight_filler.type"),
+                layer.getAttr("convolution_param.weight_filler.value"));
+
+        String biasInit = gh.getInit(
+                layer.getAttr("convolution_param.bias_filler.type"),
+                layer.getAttr("convolution_param.bias_filler.value"));
+
+        if (weightInit != null || layer.getParams().size() >= 1) {
+            Map<String, String> param = layer.getParams().size() >= 1 ? layer.getParams().get(0) : java.util.Collections.<String, String>emptyMap();
+            out.append(
+                    generateVar("weight", layer.getName() + "_weight",
+                            param.get("param.lr_mult"), param.get("param.decay_mult"),
+                            weightInit, null)
+            );
+            st.add("weight", "weight");
+        }
+
+        if (biasInit != null || layer.getParams().size() >= 2) {
+            Map<String, String> param = layer.getParams().size() >= 2 ? layer.getParams().get(1) : java.util.Collections.<String, String>emptyMap();
+            out.append(
+                    generateVar("bias", layer.getName() + "_bias",
+                            param.get("param.lr_mult"), param.get("param.decay_mult"),
+                            biasInit, null)
+            );
+            st.add("bias", "bias");
+        }
+
+        out.append(st.render());
+        return new GeneratorOutput(out.toString(), 1);
+    }
+}
+
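The simpleFillTemplate calls read a per-dimension key, such as convolution_param.kernel_h. They appear to fall back to the shared key (convolution_param.kernel_size) and then to a default. GenerationHelper is not part of this excerpt, so the resolver below is a hypothetical sketch of that assumed lookup order.

The sketch also shows a bounds-checked way to read the optional Caffe param { ... } blocks that carry lr_mult and decay_mult. A weight_filler can appear without any param block, and reading get(0) on an empty list would throw.

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

class ConvAttrSketch {
    // Hypothetical: value of the first key present in attrs, or defaultValue
    // (which may be null) when none of the keys is set.
    static String resolve(Map<String, String> attrs, String defaultValue, String... keys) {
        for (String key : keys) {
            String value = attrs.get(key);
            if (value != null) {
                return value;
            }
        }
        return defaultValue;
    }

    // Hypothetical: the i-th param block, or an empty map when the prototxt
    // declares fewer blocks than that.
    static Map<String, String> paramOrEmpty(List<Map<String, String>> params, int i) {
        return i < params.size() ? params.get(i) : Collections.<String, String>emptyMap();
    }
}
```

An empty map yields null for param.lr_mult and param.decay_mult. BaseGenerator.generateVar already renders null as None.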
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/DeconvolutionGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/DeconvolutionGenerator.java
new file mode 100644
index 0000000..5e79fed
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/DeconvolutionGenerator.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file DeconvolutionGenerator.java
+ * \brief Generate Deconvolution layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+import java.util.Map;
+
+public class DeconvolutionGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        StringBuilder out = new StringBuilder();
+        ST st = getTemplate("deconvolution");
+        gh.fillNameDataAndVar(st, layer);
+
+        // Set kernel size
+        gh.simpleFillTemplate(st, "kernel_h", layer, "convolution_param.kernel_h", null,
+                "convolution_param.kernel_size");
+        gh.simpleFillTemplate(st, "kernel_w", layer, "convolution_param.kernel_w", null,
+                "convolution_param.kernel_size");
+
+        // Set stride
+        gh.simpleFillTemplate(st, "stride_h", layer, "convolution_param.stride_h", "1",
+                "convolution_param.stride");
+        gh.simpleFillTemplate(st, "stride_w", layer, "convolution_param.stride_w", "1",
+                "convolution_param.stride");
+
+        // Set padding
+        gh.simpleFillTemplate(st, "pad_h", layer, "convolution_param.pad_h", "0",
+                "convolution_param.pad");
+        gh.simpleFillTemplate(st, "pad_w", layer, "convolution_param.pad_w", "0",
+                "convolution_param.pad");
+
+        // Use bias?
+        if (layer.attrEquals("convolution_param.bias_term", "false")) {
+            st.add("no_bias", "NoBiasPlease");
+        }
+
+        // Number of channels in output
+        gh.simpleFillTemplate(st, "num_filter", layer, "convolution_param.num_output", null);
+
+        // Group
+        gh.simpleFillTemplate(st, "group", layer, "convolution_param.group", "PP_REMOVE");
+
+
+        // Custom weight and bias if needed
+        String weightInit = gh.getInit(
+                layer.getAttr("convolution_param.weight_filler.type"),
+                layer.getAttr("convolution_param.weight_filler.value"));
+
+        String biasInit = gh.getInit(
+                layer.getAttr("convolution_param.bias_filler.type"),
+                layer.getAttr("convolution_param.bias_filler.value"));
+
+        if (weightInit != null || layer.getParams().size() >= 1) {
+            Map<String, String> param = layer.getParams().size() >= 1 ? layer.getParams().get(0) : java.util.Collections.<String, String>emptyMap();
+            out.append(
+                    generateVar("weight", layer.getName() + "_weight",
+                            param.get("param.lr_mult"), param.get("param.decay_mult"),
+                            weightInit, null)
+            );
+            st.add("weight", "weight");
+        }
+
+        if (biasInit != null || layer.getParams().size() >= 2) {
+            Map<String, String> param = layer.getParams().size() >= 2 ? layer.getParams().get(1) : java.util.Collections.<String, String>emptyMap();
+            out.append(
+                    generateVar("bias", layer.getName() + "_bias",
+                            param.get("param.lr_mult"), param.get("param.decay_mult"),
+                            biasInit, null)
+            );
+        }
+
+        out.append(st.render());
+        return new GeneratorOutput(out.toString(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/DropoutGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/DropoutGenerator.java
new file mode 100644
index 0000000..198f3b0
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/DropoutGenerator.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file DropoutGenerator.java
+ * \brief Generate Dropout layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class DropoutGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("dropout");
+        gh.fillNameDataAndVar(st, layer);
+
+        gh.simpleFillTemplate(st, "prob", layer, "dropout_param.dropout_ratio", "0.5");
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/EltwiseGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/EltwiseGenerator.java
new file mode 100644
index 0000000..edd5765
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/EltwiseGenerator.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file EltwiseGenerator.java
+ * \brief Generate Eltwise layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+import java.util.List;
+
+public class EltwiseGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        String operation = layer.getAttr("eltwise_param.operation");
+        if (operation == null) {
+            operation = "SUM";
+        }
+
+        ST st;
+        switch (operation) {
+            case "SUM":
+                st = getTemplate("add");
+                break;
+            case "PROD":
+                st = getTemplate("mul");
+                break;
+            case "MAX":
+                st = getTemplate("maximum");
+                break;
+            default:
+                String error = "Unrecognized operation " + operation + " in Eltwise" + System.lineSeparator();
+                System.err.print(error);
+                return new GeneratorOutput(error, 1);
+        }
+
+        st.add("name", layer.getName());
+        st.add("var", gh.getVarname(layer.getTop()));
+
+        List<String> data = gh.getVarNames(layer.getBottoms());
+        st.add("data1", data.get(0));
+        st.add("data2", data.get(1));
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
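EltwiseGenerator maps Caffe's eltwise_param.operation to a template name, defaulting to SUM when the operation is unset. It also reads only the first two bottoms. The sketch below restates the mapping. It adds a hypothetical pairwise fold showing how more than two inputs could be chained. The fold is valid here because sum, product and max are all associative.

The fold produces an illustrative string only and is not the translator's output format.

```java
import java.util.List;

class EltwiseSketch {
    // Template selected for eltwise_param.operation; null (unset) is treated
    // as SUM, as in the commit. Returns null for unsupported operations.
    static String templateFor(String operation) {
        switch (operation == null ? "SUM" : operation) {
            case "SUM":
                return "add";
            case "PROD":
                return "mul";
            case "MAX":
                return "maximum";
            default:
                return null;
        }
    }

    // Hypothetical: chain N inputs pairwise, e.g. [a, b, c] -> add(add(a, b), c).
    static String foldPairwise(String op, List<String> inputs) {
        String acc = inputs.get(0);
        for (int i = 1; i < inputs.size(); i++) {
            acc = op + "(" + acc + ", " + inputs.get(i) + ")";
        }
        return acc;
    }
}
```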
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/FCGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/FCGenerator.java
new file mode 100644
index 0000000..753b874
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/FCGenerator.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file FCGenerator.java
+ * \brief Generate fully connected layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+import java.util.Map;
+
+public class FCGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        StringBuilder out = new StringBuilder();
+
+        ST st = getTemplate("fc");
+
+        gh.fillNameDataAndVar(st, layer);
+
+        gh.simpleFillTemplate(st, "num", layer, "inner_product_param.num_output", null);
+
+        if (layer.attrEquals("inner_product_param.bias_term", "false")) {
+            st.add("no_bias", "NoBiasPlease"); //value doesn't matter
+        }
+
+        String weightInit = gh.getInit(
+                layer.getAttr("inner_product_param.weight_filler.type"),
+                layer.getAttr("inner_product_param.weight_filler.value"));
+
+        String biasInit = gh.getInit(
+                layer.getAttr("inner_product_param.bias_filler.type"),
+                layer.getAttr("inner_product_param.bias_filler.value"));
+
+        if (weightInit != null || layer.getParams().size() >= 1) {
+            Map<String, String> param = layer.getParams().size() >= 1 ? layer.getParams().get(0) : java.util.Collections.<String, String>emptyMap();
+            out.append(
+                    generateVar("weight", layer.getName() + "_weight",
+                            param.get("param.lr_mult"), param.get("param.decay_mult"),
+                            weightInit, null)
+            );
+            st.add("weight", "weight");
+        }
+
+        if (biasInit != null || layer.getParams().size() >= 2) {
+            Map<String, String> param = layer.getParams().size() >= 2 ? layer.getParams().get(1) : java.util.Collections.<String, String>emptyMap();
+            out.append(
+                    generateVar("bias", layer.getName() + "_bias",
+                            param.get("param.lr_mult"), param.get("param.decay_mult"),
+                            biasInit, null)
+            );
+            st.add("bias", "bias");
+        }
+
+        out.append(st.render());
+
+        return gh.makeGeneratorOutput(out.toString(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/FlattenGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/FlattenGenerator.java
new file mode 100644
index 0000000..5eb8a13
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/FlattenGenerator.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file FlattenGenerator.java
+ * \brief Generate flatten layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class FlattenGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+
+        ST st = getTemplate("flatten");
+        gh.fillNameDataAndVar(st, layer);
+
+        String axis = layer.getAttr("flatten_param.axis");
+        if (axis != null && Integer.parseInt(axis) != 1) {
+            String error = "Axis other than 1 is not supported for flatten" + System.lineSeparator();
+            System.err.print(error);
+            return new GeneratorOutput(error, 1);
+        }
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PermuteGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PermuteGenerator.java
new file mode 100644
index 0000000..f7383bc
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PermuteGenerator.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file PermuteGenerator.java
+ * \brief Generate Permute layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+import java.util.List;
+
+public class PermuteGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("permute");
+        gh.fillNameDataAndVar(st, layer);
+
+        List<String> axes = layer.getAttrList("permute_param.order");
+        if (axes != null) {
+            st.add("axes", axes);
+        }
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginIntLayerGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginIntLayerGenerator.java
new file mode 100644
index 0000000..048b537
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginIntLayerGenerator.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file PluginIntLayerGenerator.java
+ * \brief Generate a layer using Caffe Plugin
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class PluginIntLayerGenerator extends BaseGenerator {
+
+    private final PluginLayerHelper helper;
+
+
+    public PluginIntLayerGenerator() {
+        super();
+        helper = new PluginLayerHelper();
+    }
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        return generate(layer, model, 0);
+    }
+
+    public GeneratorOutput generate(Layer layer, MLModel model, int num_weight) {
+        ST st = getTemplate("CaffePluginIntLayer");
+
+        st.add("name", layer.getName());
+
+        if (layer.getBottoms().size() != 1) {
+            st.add("num_data", layer.getBottoms().size());
+        }
+        if (layer.getTops().size() != 1) {
+            st.add("num_out", layer.getTops().size());
+        }
+        if (num_weight != 0) {
+            st.add("num_weight", num_weight);
+        }
+
+        String dataList = helper.getDataList(layer);
+        st.add("data", dataList);
+
+        // Set prototxt
+        String prototxt = helper.makeOneLine(layer.getPrototxt());
+        st.add("prototxt", prototxt);
+
+        // Handle multiple outputs
+        if (layer.getTops().size() > 1) {
+            st.add("tops", layer.getTops());
+            st.add("var", "out");
+        } else if (layer.getTops().size() == 1) {
+            st.add("var", gh.getVarname(layer.getTop()));
+        }
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginLayerHelper.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginLayerHelper.java
new file mode 100644
index 0000000..3b8506c
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginLayerHelper.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file PluginLayerHelper.java
+ * \brief Helper class to generate layers using Caffe Plugin
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GenerationHelper;
+import io.mxnet.caffetranslator.Layer;
+
+public class PluginLayerHelper {
+
+    private final GenerationHelper gh;
+
+    public PluginLayerHelper() {
+        gh = new GenerationHelper();
+    }
+
+    public String getDataList(Layer layer) {
+        StringBuilder sb = new StringBuilder();
+        int index = 0;
+
+        if (layer.getBottoms().size() == 0) {
+            return null;
+        }
+
+        for (String bottom : layer.getBottoms()) {
+            sb.append("data_" + index + "=" + gh.getVarname(bottom) + ", ");
+            index++;
+        }
+        if (sb.length() > 0) {
+            sb.setLength(sb.length() - 2);
+        }
+        return sb.toString();
+    }
+
+    public String makeOneLine(String prototxt) {
+        prototxt = prototxt.replaceAll("[\\r\\n]+", " ");
+        prototxt = prototxt.replace("'", "\\'");
+        prototxt = prototxt.replaceAll("\\s{2,}", " ").trim();
+        return prototxt;
+    }
+
+}
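The two string transformations in `PluginLayerHelper` decide what the emitted plugin call looks like, so a standalone sketch of them may help when reading the generators. It is illustrative only: `PluginArgsSketch` and its methods are hypothetical names, and `varname` is an assumed stand-in for `GenerationHelper.getVarname` (which is not part of this hunk) that simply replaces non-identifier characters with underscores. The quote escaping follows the apparent intent of `makeOneLine`, namely embedding the prototxt in a single-quoted Python string; newlines are turned into spaces so adjacent fields do not run together.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PluginArgsSketch {

    // Hypothetical stand-in for GenerationHelper.getVarname: turns a blob name
    // such as "conv1/relu" into a Python identifier such as "conv1_relu".
    public static String varname(String blob) {
        return blob.replaceAll("[^A-Za-z0-9_]", "_");
    }

    // Joins the bottom blobs as keyword arguments: "data_0=a, data_1=b".
    // Returns null when there are no bottoms, like getDataList does.
    public static String dataList(List<String> bottoms) {
        if (bottoms.isEmpty()) {
            return null;
        }
        List<String> parts = new ArrayList<>();
        for (int i = 0; i < bottoms.size(); i++) {
            parts.add("data_" + i + "=" + varname(bottoms.get(i)));
        }
        return String.join(", ", parts);
    }

    // Collapses a multi-line prototxt into one line and prefixes single quotes
    // with a backslash (String.replace is literal, unlike replaceAll).
    public static String oneLine(String prototxt) {
        String s = prototxt.replaceAll("[\\r\\n]+", " ");
        s = s.replace("'", "\\'");
        return s.replaceAll("\\s{2,}", " ").trim();
    }

    public static void main(String[] args) {
        // prints data_0=conv1, data_1=label
        System.out.println(dataList(Arrays.asList("conv1", "label")));
        // prints layer { name: \'scale1\' type: "Scale" }
        System.out.println(oneLine("layer {\n  name: 'scale1'\n  type: \"Scale\"\n}"));
    }
}
```

Escaping with `String.replace` rather than `replaceAll` sidesteps the regex replacement syntax, where a lone backslash has special meaning.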
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginLossGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginLossGenerator.java
new file mode 100644
index 0000000..5e98151
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PluginLossGenerator.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file PluginLossGenerator.java
+ * \brief Generate loss layer using Caffe Plugin
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class PluginLossGenerator extends BaseGenerator {
+
+    private final PluginLayerHelper helper;
+
+    public PluginLossGenerator() {
+        super();
+        helper = new PluginLayerHelper();
+    }
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("CaffePluginLossLayer");
+
+        st.add("name", layer.getName());
+
+        // Handle data
+        if (layer.getBottoms().size() != 1) {
+            st.add("num_data", layer.getBottoms().size());
+        }
+        String dataList = helper.getDataList(layer);
+        st.add("data", dataList);
+
+        // Set prototxt
+        String prototxt = helper.makeOneLine(layer.getPrototxt());
+        st.add("prototxt", prototxt);
+
+        // Handle multiple outputs
+        if (layer.getTops().size() > 1) {
+            st.add("tops", layer.getTops());
+            st.add("var", "out");
+        } else if (layer.getTops().size() == 1) {
+            st.add("var", gh.getVarname(layer.getTop()));
+        }
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PoolingGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PoolingGenerator.java
new file mode 100644
index 0000000..ad91f58
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PoolingGenerator.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file PoolingGenerator.java
+ * \brief Generate Pooling layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class PoolingGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("pooling");
+
+        gh.fillNameDataAndVar(st, layer);
+
+        boolean globalPooling = layer.getAttr("pooling_param.global_pooling", "false")
+                .toLowerCase().equals("true");
+
+        if (globalPooling) {
+            st.add("global_pool", "True");
+            st.add("kernel_h", "1");
+            st.add("kernel_w", "1");
+        } else {
+            // Set kernel size
+            gh.simpleFillTemplate(st, "kernel_h", layer, "pooling_param.kernel_h", null,
+                    "pooling_param.kernel_size");
+            gh.simpleFillTemplate(st, "kernel_w", layer, "pooling_param.kernel_w", null,
+                    "pooling_param.kernel_size");
+        }
+
+        // Set stride
+        gh.simpleFillTemplate(st, "stride_h", layer, "pooling_param.stride_h", "1",
+                "pooling_param.stride");
+        gh.simpleFillTemplate(st, "stride_w", layer, "pooling_param.stride_w", "1",
+                "pooling_param.stride");
+
+        // Set padding
+        gh.simpleFillTemplate(st, "pad_h", layer, "pooling_param.pad_h", "0",
+                "pooling_param.pad");
+        gh.simpleFillTemplate(st, "pad_w", layer, "pooling_param.pad_w", "0",
+                "pooling_param.pad");
+
+        // Set type
+        String poolType = layer.getAttr("pooling_param.pool", "MAX");
+        switch (poolType) {
+            case "MAX":
+                st.add("type", "max");
+                break;
+            case "AVE":
+                st.remove("type");
+                st.add("type", "avg");
+                break;
+            case "STOCHASTIC":
+                System.err.println("Stochastic pooling type not supported.");
+                st.add("type", "???");
+                break;
+        }
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+
+}
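The pooling type is the only enum the pooling generator translates. The sketch below isolates that mapping under a few assumptions. `PoolTypeSketch` is a hypothetical name. A missing `pool` field falls back to `MAX`, which is Caffe's default. Unsupported types are rejected with an exception. That last behavior is a design choice of this sketch; the translator itself prints a warning and emits `'???'`.

```java
import java.util.Locale;

public class PoolTypeSketch {

    // Maps Caffe's PoolingParameter.pool enum name to an MXNet pool_type string.
    // A missing value falls back to MAX, which is Caffe's default.
    public static String toMxnetPoolType(String caffePool) {
        String pool = (caffePool == null) ? "MAX" : caffePool.toUpperCase(Locale.ROOT);
        switch (pool) {
            case "MAX":
                return "max";
            case "AVE":
                return "avg";
            default:
                // STOCHASTIC (and anything unknown) has no MXNet equivalent.
                throw new IllegalArgumentException("Unsupported pooling type: " + caffePool);
        }
    }

    public static void main(String[] args) {
        System.out.println(toMxnetPoolType(null));  // prints max
        System.out.println(toMxnetPoolType("AVE")); // prints avg
    }
}
```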
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PowerGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PowerGenerator.java
new file mode 100644
index 0000000..d467650
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/PowerGenerator.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file PowerGenerator.java
+ * \brief Generate Power layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class PowerGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("power");
+
+        String power = layer.getAttr("power_param.power", "1");
+        String scale = layer.getAttr("power_param.scale", "1");
+        String shift = layer.getAttr("power_param.shift", "0");
+
+        st.add("var", gh.getVarname(layer.getTop()));
+        st.add("data", gh.getVarname(layer.getBottom()));
+
+        st.add("power", power);
+        st.add("scale", scale);
+        st.add("shift", shift);
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ReluGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ReluGenerator.java
new file mode 100644
index 0000000..37ac9a8
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ReluGenerator.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ReluGenerator.java
+ * \brief Generate Relu layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class ReluGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("activation");
+
+        gh.fillNameDataAndVar(st, layer);
+        st.add("type", "relu");
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ScaleGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ScaleGenerator.java
new file mode 100644
index 0000000..fc919e3
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/ScaleGenerator.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ScaleGenerator.java
+ * \brief Generate Scale layer
+ */
+
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+
+public class ScaleGenerator extends BaseGenerator {
+
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        PluginIntLayerGenerator generator = new PluginIntLayerGenerator();
+
+        boolean use_bias = layer.getAttr("scale_param.bias_term", "false").toLowerCase().equals("true");
+
+        StringBuilder out = new StringBuilder();
+
+        if (use_bias) {
+            out.append(generator.generate(layer, model, 2).code);
+        } else {
+            out.append(generator.generate(layer, model, 1).code);
+        }
+
+        String fillerType = layer.getAttr("scale_param.filler.type");
+        String fillerValue = layer.getAttr("scale_param.filler.value");
+        if (fillerType == null && fillerValue == null) {
+            fillerValue = "1";
+        }
+        out.append(gh.initializeParam(gh.getVarname(layer.getTop()), 1, gh.getInit(fillerType, fillerValue)));
+
+        if (use_bias) {
+            fillerType = layer.getAttr("scale_param.bias_filler.type");
+            fillerValue = layer.getAttr("scale_param.bias_filler.value");
+            if (fillerType == null && fillerValue == null) {
+                fillerValue = "0";
+            }
+            out.append(gh.initializeParam(gh.getVarname(layer.getTop()), 2, gh.getInit(fillerType, fillerValue)));
+        }
+
+        return gh.makeGeneratorOutput(out.toString(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/SoftmaxOutputGenerator.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/SoftmaxOutputGenerator.java
new file mode 100644
index 0000000..a017e4f
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/generators/SoftmaxOutputGenerator.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file SoftmaxOutputGenerator.java
+ * \brief Generate SoftmaxOutput layer
+ */
+
+package io.mxnet.caffetranslator.generators;
+
+import io.mxnet.caffetranslator.GeneratorOutput;
+import io.mxnet.caffetranslator.Layer;
+import io.mxnet.caffetranslator.MLModel;
+import org.stringtemplate.v4.ST;
+
+public class SoftmaxOutputGenerator extends BaseGenerator {
+    @Override
+    public GeneratorOutput generate(Layer layer, MLModel model) {
+        ST st = getTemplate("softmaxoutput");
+        gh.fillNameDataAndVar(st, layer);
+
+        st.add("label", gh.getVarname(layer.getBottoms().get(1)));
+        st.add("label_name", layer.getBottoms().get(1));
+
+        return new GeneratorOutput(st.render(), 1);
+    }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/misc/CollectStats.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/misc/CollectStats.java
new file mode 100644
index 0000000..e38c2d0
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/misc/CollectStats.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file CollectStats.java
+ * \brief Print all unique layers used in a prototxt along with the parameters used in each layer type.
+ */
+
+package io.mxnet.caffetranslator.misc;
+
+import io.mxnet.caffetranslator.CaffePrototxtLexer;
+import io.mxnet.caffetranslator.CaffePrototxtParser;
+import org.antlr.v4.runtime.CharStream;
+import org.antlr.v4.runtime.CharStreams;
+import org.antlr.v4.runtime.CommonTokenStream;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+
+public class CollectStats {
+
+    public static void main(String[] args) {
+        String filePath = args.length > 0 ? args[0] : "path";
+
+        CharStream cs = null;
+        try {
+            FileInputStream fis = new FileInputStream(new File(filePath));
+            cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
+        } catch (Exception e) {
+            throw new RuntimeException("Unable to read " + filePath, e);
+        }
+
+        CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
+        CommonTokenStream tokens = new CommonTokenStream(lexer);
+        CaffePrototxtParser parser = new CaffePrototxtParser(tokens);
+
+        StatsListener statsListener = new StatsListener();
+        parser.addParseListener(statsListener);
+        parser.prototxt();
+
+        Map<String, Set<String>> attrMap = statsListener.getAttrMap();
+
+        Iterator<Map.Entry<String, Set<String>>> it = attrMap.entrySet().iterator();
+        while (it.hasNext()) {
+            Map.Entry<String, Set<String>> pair = it.next();
+            System.out.println(pair.getKey() + ":");
+            for (String value : pair.getValue()) {
+                System.out.println("    " + value);
+            }
+        }
+    }
+
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/misc/StatsListener.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/misc/StatsListener.java
new file mode 100644
index 0000000..71f9662
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/misc/StatsListener.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file StatsListener.java
+ * \brief ANTLR listener to collect stats used by CollectStats.java
+ */
+
+package io.mxnet.caffetranslator.misc;
+
+import io.mxnet.caffetranslator.CaffePrototxtBaseListener;
+import io.mxnet.caffetranslator.CaffePrototxtParser;
+import io.mxnet.caffetranslator.ParserHelper;
+import lombok.Getter;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.Stack;
+import java.util.TreeMap;
+import java.util.TreeSet;
+
+public class StatsListener extends CaffePrototxtBaseListener {
+
+    private final Stack<String> keys;
+    @Getter
+    private final Map<String, Set<String>> attrMap;
+    private final ParserHelper parserHelper;
+
+    private String layerType;
+    private Set<String> curAttr;
+
+    public StatsListener() {
+        attrMap = new TreeMap<>();
+        keys = new Stack<>();
+        parserHelper = new ParserHelper();
+    }
+
+    @Override
+    public void enterLayer(CaffePrototxtParser.LayerContext ctx) {
+        keys.clear();
+        curAttr = new TreeSet<>();
+    }
+
+    @Override
+    public void exitLayer(CaffePrototxtParser.LayerContext ctx) {
+        if (!attrMap.containsKey(layerType)) {
+            attrMap.put(layerType, new TreeSet<>());
+        }
+        Set<String> set = attrMap.get(layerType);
+        set.addAll(curAttr);
+    }
+
+    @Override
+    public void exitValueLeaf(CaffePrototxtParser.ValueLeafContext ctx) {
+        String value = ctx.getText();
+        value = parserHelper.removeQuotes(value);
+        processKeyValue(getCurrentKey(), value);
+    }
+
+    private void processKeyValue(String key, String value) {
+        if (key.equals("type")) {
+            layerType = value;
+        } else {
+            curAttr.add(key);
+        }
+    }
+
+    @Override
+    public void enterPair(CaffePrototxtParser.PairContext ctx) {
+        String key = ctx.getStart().getText();
+        keys.push(key);
+    }
+
+    @Override
+    public void exitPair(CaffePrototxtParser.PairContext ctx) {
+        keys.pop();
+    }
+
+    private String getCurrentKey() {
+        StringBuilder sb = new StringBuilder();
+        for (String s : keys) {
+            sb.append(s + ".");
+        }
+        return sb.length() == 0 ? "" : sb.substring(0, sb.length() - 1);
+    }
+
+}
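Attribute names such as `pooling_param.kernel_size` or `scale_param.bias_term`, which the generators pass to `getAttr`, appear to be dotted paths of nested prototxt keys, built the way `StatsListener.getCurrentKey` builds them. Below is a minimal sketch of that bookkeeping. `KeyPathSketch` is a hypothetical name. Using an `ArrayDeque` in place of `Stack` is a choice made for this sketch, and `String.join` returns an empty string when no pair is open.

```java
import java.util.ArrayDeque;
import java.util.Deque;

public class KeyPathSketch {

    // Keys of the prototxt pairs that are currently open, outermost first.
    private final Deque<String> keys = new ArrayDeque<>();

    // Called when a "key: value" or "key { ... }" pair is entered.
    public void enterPair(String key) {
        keys.addLast(key);
    }

    // Called when that pair is closed again.
    public void exitPair() {
        keys.removeLast();
    }

    // Dotted path such as "pooling_param.kernel_size"; empty if no pair is open.
    public String currentKey() {
        return String.join(".", keys);
    }

    public static void main(String[] args) {
        // Simulates walking: pooling_param { kernel_size: 3 }
        KeyPathSketch k = new KeyPathSketch();
        k.enterPair("pooling_param");
        k.enterPair("kernel_size");
        System.out.println(k.currentKey()); // prints pooling_param.kernel_size
        k.exitPair();
        k.exitPair();
    }
}
```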
diff --git a/tools/caffe_translator/src/main/resources/templates/accuracy.st b/tools/caffe_translator/src/main/resources/templates/accuracy.st
new file mode 100644
index 0000000..f741def
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/accuracy.st
@@ -0,0 +1,2 @@
+<var> = mx.metric.Accuracy(output_names=['<output_name>'], label_names=['<label_name>'], name='<name>')
+test_metrics.add(<var>)
diff --git a/tools/caffe_translator/src/main/resources/templates/activation.st b/tools/caffe_translator/src/main/resources/templates/activation.st
new file mode 100644
index 0000000..5a9c37b
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/activation.st
@@ -0,0 +1 @@
+<var> = mx.symbol.Activation(data=<data>, act_type='<type>', name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/add.st b/tools/caffe_translator/src/main/resources/templates/add.st
new file mode 100644
index 0000000..ca9428f
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/add.st
@@ -0,0 +1 @@
+<var> = <data1> + <data2>
diff --git a/tools/caffe_translator/src/main/resources/templates/batchnorm.st b/tools/caffe_translator/src/main/resources/templates/batchnorm.st
new file mode 100644
index 0000000..c043c70
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/batchnorm.st
@@ -0,0 +1,14 @@
+<if(fix_beta)>
+<var>_beta = mx.sym.BlockGrad(mx.sym.Variable("<name>_beta", init=mx.init.Constant(0)))
+<endif>
+<var> = mx.symbol.BatchNorm(data=<data>,
+<if(fix_beta)>
+    beta=<var>_beta,
+<endif>
+<if(fix_gamma)>
+    fix_gamma=True,
+<endif>
+<if(use_global_stats)>
+    use_global_stats=True,
+<endif>
+    name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/concat.st b/tools/caffe_translator/src/main/resources/templates/concat.st
new file mode 100644
index 0000000..75ffa3c
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/concat.st
@@ -0,0 +1 @@
+<var> = mx.sym.concat(<data;separator=", "><if(dim)>, dim=<dim><endif>, name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/convolution.st b/tools/caffe_translator/src/main/resources/templates/convolution.st
new file mode 100644
index 0000000..c4bdd51
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/convolution.st
@@ -0,0 +1,9 @@
+<var> = mx.sym.Convolution(data=<data>,
+    <if(weight)>weight=<weight>,<endif>
+    <if(bias)>bias=<bias>,<endif>
+    kernel=(<kernel_h>,<kernel_w>),
+    stride=(<stride_h>,<stride_w>),
+    pad=(<pad_h>,<pad_w>),
+    num_filter=<num_filter>,
+    <if(no_bias)>no_bias=True,<endif>
+    name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/deconvolution.st b/tools/caffe_translator/src/main/resources/templates/deconvolution.st
new file mode 100644
index 0000000..5b63f56
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/deconvolution.st
@@ -0,0 +1,10 @@
+<var> = mx.sym.Deconvolution(data=<data>,
+    <if(use_weight)>weight=weight,<endif>
+    <if(use_bias)>bias=bias,<endif>
+    kernel=(<kernel_h>,<kernel_w>),
+    stride=(<stride_h>,<stride_w>),
+    pad=(<pad_h>,<pad_w>),
+    num_filter=<num_filter>,
+    num_group=<group>,
+    <if(no_bias)>no_bias=True,<endif>
+    name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/dropout.st b/tools/caffe_translator/src/main/resources/templates/dropout.st
new file mode 100644
index 0000000..9791c09
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/dropout.st
@@ -0,0 +1 @@
+<var> = mx.sym.Dropout(data=<data>, p=<prob>, name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/fc.st b/tools/caffe_translator/src/main/resources/templates/fc.st
new file mode 100644
index 0000000..22365b3
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/fc.st
@@ -0,0 +1 @@
+<var> = mx.symbol.FullyConnected(data=<data>, <if(weight)>weight=<weight>, <endif><if(bias)>bias=<bias>, <endif>num_hidden=<num>, <if(no_bias)>no_bias=True, <endif>name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/flatten.st b/tools/caffe_translator/src/main/resources/templates/flatten.st
new file mode 100644
index 0000000..8434335
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/flatten.st
@@ -0,0 +1 @@
+<var> = mx.sym.flatten(data=<data>, name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/group.st b/tools/caffe_translator/src/main/resources/templates/group.st
new file mode 100644
index 0000000..33e312f
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/group.st
@@ -0,0 +1 @@
+<var> = mx.sym.Group([<symbols;separator=", ">])
diff --git a/tools/caffe_translator/src/main/resources/templates/imports.st b/tools/caffe_translator/src/main/resources/templates/imports.st
new file mode 100644
index 0000000..b37bd33
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/imports.st
@@ -0,0 +1,7 @@
+from __future__ import division
+import copy
+import logging
+import math
+import sys
+
+import mxnet as mx
diff --git a/tools/caffe_translator/src/main/resources/templates/init_params.st b/tools/caffe_translator/src/main/resources/templates/init_params.st
new file mode 100644
index 0000000..3a277b6
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/init_params.st
@@ -0,0 +1,7 @@
+<if(params_file)>
+arg_params, aux_params = load_params('<params_file>')
+module.init_params(initializer=mx.init.Xavier(), arg_params=arg_params, aux_params=aux_params,
+                   allow_missing=True)
+<else>
+module.init_params(initializer=mx.init.Xavier())
+<endif>
diff --git a/tools/caffe_translator/src/main/resources/templates/iterator.st b/tools/caffe_translator/src/main/resources/templates/iterator.st
new file mode 100644
index 0000000..5bc2a9d
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/iterator.st
@@ -0,0 +1,10 @@
+<iter_name> = mx.io.CaffeDataIter(
+    prototxt=
+<prototxt>,
+    data_name='<data_name>',
+    label_name='<label_name>',
+<if(num_examples)>
+    num_examples=<num_examples>,
+<endif>
+    flat=False
+)
diff --git a/tools/caffe_translator/src/main/resources/templates/logging.st b/tools/caffe_translator/src/main/resources/templates/logging.st
new file mode 100644
index 0000000..73785e5
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/logging.st
@@ -0,0 +1,11 @@
+def get_logger(name):
+    formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
+                                  datefmt='%Y-%m-%d %H:%M:%S')
+    stdout_handler = logging.StreamHandler(stream=sys.stdout)
+    stdout_handler.setFormatter(formatter)
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.addHandler(stdout_handler)
+    return logger
+
+logger = get_logger("<name>")
diff --git a/tools/caffe_translator/src/main/resources/templates/lrn.st b/tools/caffe_translator/src/main/resources/templates/lrn.st
new file mode 100644
index 0000000..ec003c1
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrn.st
@@ -0,0 +1 @@
+<var> = mx.sym.LRN(data=<data>, alpha=<alpha>, beta=<beta>, knorm=<knorm>, nsize=<nsize>, name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st
new file mode 100644
index 0000000..43afca2
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st
@@ -0,0 +1,3 @@
+lr = optimizer_params['learning_rate']
+lr *= gamma
+optimizer_params['learning_rate'] = lr
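The exp template decays the rate incrementally, multiplying the current value by gamma once per call. Assuming it runs exactly once per iteration starting from base_lr, this reproduces Caffe's closed form base_lr * gamma^iter. A small pure-Python check (the helper names are illustrative only, not part of the translator):

```python
import math

def exp_closed_form(base_lr, gamma, it):
    # Caffe "exp" policy: base_lr * gamma ^ iter
    return base_lr * math.pow(gamma, it)

def exp_incremental(base_lr, gamma, iters):
    # Same update as the exp template: multiply the current rate by gamma each call
    lr = base_lr
    for _ in range(iters):
        lr *= gamma
    return lr

print(exp_closed_form(0.01, 0.999, 1000))
print(exp_incremental(0.01, 0.999, 1000))
```

The two values agree up to floating-point rounding.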
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st
new file mode 100644
index 0000000..5da8aa6
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st
@@ -0,0 +1,3 @@
+lr = optimizer_params['learning_rate']
+lr = base_lr * math.pow((1 + gamma * batch_num), -power)
+optimizer_params['learning_rate'] = lr
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st
new file mode 100644
index 0000000..fe09301
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st
@@ -0,0 +1,5 @@
+lr_update_steps = [<steps;separator=", ">]
+if(batch_num in lr_update_steps):
+    lr = optimizer_params['learning_rate']
+    lr *= gamma
+    optimizer_params['learning_rate'] = lr
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st
new file mode 100644
index 0000000..e43fd78
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st
@@ -0,0 +1,3 @@
+lr = optimizer_params['learning_rate']
+lr = base_lr * math.pow(1 - batch_num/max_iter, power)
+optimizer_params['learning_rate'] = lr
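Caffe defines the poly policy as base_lr * (1 - iter/max_iter)^power: base_lr scales the result and is not itself raised to the power. A minimal sketch of that formula (poly_lr is a hypothetical helper) shows the expected shape. For example, with base_lr=0.1 and power=2 the rate at the halfway point is 0.025, and it reaches 0 at max_iter:

```python
import math

def poly_lr(base_lr, it, max_iter, power):
    # Caffe "poly": base_lr * (1 - iter / max_iter) ^ power
    # float() keeps the division exact under Python 2 as well.
    return base_lr * math.pow(1.0 - float(it) / max_iter, power)

print(poly_lr(0.1, 0, 1000, 2.0))
print(poly_lr(0.1, 500, 1000, 2.0))
print(poly_lr(0.1, 1000, 1000, 2.0))
```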
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st
new file mode 100644
index 0000000..33ba055
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st
@@ -0,0 +1,3 @@
+lr = optimizer_params['learning_rate']
+lr = base_lr / (1 + math.exp(-gamma * (batch_num - stepsize)))
+optimizer_params['learning_rate'] = lr
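In the sigmoid policy, stepsize marks the midpoint of the transition and the sign of gamma sets its direction. With a negative gamma the rate decays smoothly and equals exactly base_lr / 2 at iter == stepsize. A pure-Python sketch of Caffe's formula (sigmoid_lr and the sample values are illustrative):

```python
import math

def sigmoid_lr(base_lr, gamma, it, stepsize):
    # Caffe "sigmoid": base_lr * 1 / (1 + exp(-gamma * (iter - stepsize)))
    return base_lr / (1.0 + math.exp(-gamma * (it - stepsize)))

# gamma < 0: high rate early, base_lr / 2 at the midpoint, near zero late.
for it in (0, 5000, 10000):
    print(it, sigmoid_lr(0.1, -0.001, it, 5000))
```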
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st
new file mode 100644
index 0000000..04468ae
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st
@@ -0,0 +1,4 @@
+if(batch_num % stepsize == 0):
+    lr = optimizer_params['learning_rate']
+    lr *= gamma
+    optimizer_params['learning_rate'] = lr
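The step and multistep templates apply gamma only at boundary iterations. Caffe expresses the same schedules in closed form: step uses base_lr * gamma^floor(iter/stepsize), and multistep uses one factor of gamma for each step value already reached. The sketch below compares the closed forms with an incremental loop modeled on the step template. It assumes the update runs once per iteration with batch_num counting from 1; all helper names are hypothetical:

```python
import math

def step_lr(base_lr, gamma, it, stepsize):
    # Caffe "step": base_lr * gamma ^ floor(iter / stepsize)
    return base_lr * math.pow(gamma, it // stepsize)

def multistep_lr(base_lr, gamma, it, stepvalues):
    # Caffe "multistep": one factor of gamma per step value already reached
    passed = sum(1 for s in stepvalues if it >= s)
    return base_lr * math.pow(gamma, passed)

def incremental_step_lr(base_lr, gamma, iters, stepsize):
    # Modeled on the step template: multiply by gamma when batch_num % stepsize == 0
    lr = base_lr
    for batch_num in range(1, iters + 1):
        if batch_num % stepsize == 0:
            lr *= gamma
    return lr

print(step_lr(0.01, 0.1, 25000, 10000))
print(incremental_step_lr(0.01, 0.1, 25000, 10000))
print(multistep_lr(0.01, 0.1, 15000, [10000, 20000]))
```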
diff --git a/tools/caffe_translator/src/main/resources/templates/maxium.st b/tools/caffe_translator/src/main/resources/templates/maxium.st
new file mode 100644
index 0000000..d9431dd
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/maxium.st
@@ -0,0 +1 @@
+<var> = mx.sym.maximum(<data1>, <data2>)
diff --git a/tools/caffe_translator/src/main/resources/templates/metrics_classes.st b/tools/caffe_translator/src/main/resources/templates/metrics_classes.st
new file mode 100644
index 0000000..e8323fb
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/metrics_classes.st
@@ -0,0 +1,87 @@
+class TrainMetrics(object):
+
+    def __init__(self, display=None, average_loss=1):
+        # Per-instance map: metric name -> list of metric objects being accumulated.
+        self.metric_map = {}
+        self.average_loss = average_loss
+        self.display = display
+
+
+    def process(self, batch_num, module, label):
+        if self.display is None:
+            return
+
+        if self.average_loss == 1:
+            if batch_num % self.display == 0:
+                self.update_metrics(module, label, reset=True)
+                self.print_metrics(batch_num)
+        else:
+            # Metrics will be printed 'average_loss' iterations from now;
+            # append a fresh metric so it can start accumulating.
+            if((batch_num + self.average_loss) % self.display == 0):
+                self.append_one()
+
+            # If the next display step is 1 to 'average_loss' iterations away,
+            # update the metrics.
+            if((batch_num + self.average_loss) % self.display \< self.average_loss):
+                self.update_metrics(module, label)
+
+            # If I'm at a display step, print the metrics.
+            if(batch_num % self.display == 0):
+                self.print_metrics(batch_num, remove_heads=True)
+
+    def add(self, metric):
+        self.metric_map[metric.name] = [metric]
+
+    def append_one(self):
+        for key, lst in self.metric_map.items():
+            last_element = lst[-1]
+            new_element = copy.deepcopy(last_element)
+            new_element.reset()
+            lst.append(new_element)
+
+    def update_metrics(self, module, label, reset=False):
+        for key, lst in self.metric_map.items():
+            for metric in lst:
+                if reset:
+                    metric.reset()
+                module.update_metric(metric, label)
+
+    def print_metrics(self, batch_num, remove_heads=False):
+
+        total_loss = 0
+        for key, lst in self.metric_map.items():
+            total_loss += lst[0].get()[1]
+
+        logger.info("Iteration %d, loss = %f" % (batch_num, total_loss))
+
+        for key, lst in self.metric_map.items():
+            if remove_heads:
+                metric = lst.pop(0)
+            else:
+                metric = lst[0]
+
+            logger.info("    %s" % metric)
+
+
+class TestMetrics(object):
+    def __init__(self):
+        self.metrics = []
+
+    def add(self, metric):
+        self.metrics.append(metric)
+
+    def score_and_print(self, module, itr, num_batch):
+        for metric in self.metrics:
+            metric.reset()
+            module.score(itr, metric, num_batch=num_batch)
+            logger.info("    %s" % metric)
+
+<if(display)>
+display = <display>
+<endif>
+<if(average_loss)>
+average_loss = <average_loss>
+<endif>
+train_metrics = TrainMetrics(<if(display)>display=display<endif><if(average_loss)><if(display)>, <endif>average_loss=average_loss<endif>)
+test_metrics = TestMetrics()
diff --git a/tools/caffe_translator/src/main/resources/templates/mul.st b/tools/caffe_translator/src/main/resources/templates/mul.st
new file mode 100644
index 0000000..411a407
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/mul.st
@@ -0,0 +1 @@
+<var> = <data1> * (<data2>)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_default.st b/tools/caffe_translator/src/main/resources/templates/opt_default.st
new file mode 100644
index 0000000..e5a72ac
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_default.st
@@ -0,0 +1,15 @@
+<if(lr)>
+base_lr = <lr>
+<endif>
+<if(momentum)>
+momentum = <momentum>
+<endif>
+<if(wd)>
+wd = <wd>
+<endif>
+<if(epsilon)>
+epsilon = <epsilon>
+<endif>
+
+optimizer_params={'learning_rate':base_lr <if(momentum)>, 'momentum':momentum<endif><if(wd)>, 'wd':wd<endif><if(epsilon)>, 'epsilon':epsilon<endif>}
+module.init_optimizer(optimizer='<opt_name>', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_sgd.st b/tools/caffe_translator/src/main/resources/templates/opt_sgd.st
new file mode 100644
index 0000000..8a24e05
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_sgd.st
@@ -0,0 +1,12 @@
+<if(lr)>
+base_lr = <lr>
+<endif>
+<if(momentum)>
+momentum = <momentum>
+<endif>
+<if(wd)>
+wd = <wd>
+<endif>
+
+optimizer_params={'learning_rate':base_lr <if(momentum)>, 'momentum':momentum<endif><if(wd)>, 'wd':wd<endif>}
+module.init_optimizer(optimizer='<opt_name>', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/param_initializer.st b/tools/caffe_translator/src/main/resources/templates/param_initializer.st
new file mode 100644
index 0000000..b496fc3
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/param_initializer.st
@@ -0,0 +1,14 @@
+class ParamInitializer(object):
+    def __init__(self):
+        self.lst_patterns = []
+        self.lst_initializers = []
+
+    def add_param(self, pattern, initializer):
+        self.lst_patterns.append(pattern)
+        self.lst_initializers.append(initializer)
+
+    def get_initializer(self, default_initializer):
+        # '.*' is added last as a catch-all, so earlier patterns take precedence.
+        patterns = self.lst_patterns + [".*"]
+        initializers = self.lst_initializers + [default_initializer]
+        return mx.initializer.Mixed(patterns, initializers)
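ParamInitializer relies on mx.initializer.Mixed trying its patterns in order. This sketch assumes Mixed assigns each parameter the initializer of the first regular expression that matches its name, using a start-anchored match. Under that assumption, the trailing '.*' only catches parameters that no explicit pattern claimed. The pure-Python dispatch below mimics the behavior; pick_initializer, the parameter names, and the string stand-ins for initializer objects are all hypothetical:

```python
import re

def pick_initializer(name, patterns, initializers):
    # First pattern (re.match, anchored at the start of the name) wins
    for pattern, init in zip(patterns, initializers):
        if re.match(pattern, name):
            return init
    raise ValueError("No initializer matches parameter " + name)

# Strings stand in for real initializer objects in this illustration.
patterns = ["conv1_weight", "fc.*_bias", ".*"]
initializers = ["gaussian", "zero", "xavier"]

print(pick_initializer("conv1_weight", patterns, initializers))  # gaussian
print(pick_initializer("fc8_bias", patterns, initializers))      # zero
print(pick_initializer("conv2_weight", patterns, initializers))  # xavier
```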
diff --git a/tools/caffe_translator/src/main/resources/templates/params_loader.st b/tools/caffe_translator/src/main/resources/templates/params_loader.st
new file mode 100644
index 0000000..22efec4
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/params_loader.st
@@ -0,0 +1,13 @@
+def load_params(params_file):
+    save_dict = mx.nd.load(params_file)
+    arg_params = {}
+    aux_params = {}
+    for k, value in save_dict.items():
+        arg_type, name = k.split(':', 1)
+        if arg_type == 'arg':
+            arg_params[name] = value
+        elif arg_type == 'aux':
+            aux_params[name] = value
+        else:
+            raise ValueError("Invalid param file " + params_file)
+    return arg_params, aux_params
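load_params depends on the key convention used by MXNet's Module.save_params, which save_checkpoint calls. Each array is stored under "arg:<name>" for learnable parameters or "aux:<name>" for auxiliary states such as BatchNorm moving statistics. The round trip below uses plain floats in place of NDArrays so it runs without MXNet. Both helper names are hypothetical:

```python
def to_saved_keys(arg_params, aux_params):
    # Prefix names the way Module.save_params does before calling mx.nd.save
    save_dict = {'arg:' + k: v for k, v in arg_params.items()}
    save_dict.update({'aux:' + k: v for k, v in aux_params.items()})
    return save_dict

def split_saved_keys(save_dict):
    # Same rule as load_params: split "<type>:<name>" on the first colon only
    arg_params, aux_params = {}, {}
    for key, value in save_dict.items():
        arg_type, name = key.split(':', 1)
        if arg_type == 'arg':
            arg_params[name] = value
        elif arg_type == 'aux':
            aux_params[name] = value
        else:
            raise ValueError("Unexpected key in params file: " + key)
    return arg_params, aux_params

saved = to_saved_keys({'conv1_weight': 1.0}, {'bn1_moving_mean': 0.0})
print(sorted(saved))  # ['arg:conv1_weight', 'aux:bn1_moving_mean']
print(split_saved_keys(saved))
```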
diff --git a/tools/caffe_translator/src/main/resources/templates/permute.st b/tools/caffe_translator/src/main/resources/templates/permute.st
new file mode 100644
index 0000000..2b06a76
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/permute.st
@@ -0,0 +1 @@
+<var> = mx.sym.transpose(data=<data>, axes=(<axes;separator=", ">), name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/pooling.st b/tools/caffe_translator/src/main/resources/templates/pooling.st
new file mode 100644
index 0000000..5389754
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/pooling.st
@@ -0,0 +1,16 @@
+<var> = mx.symbol.Pooling(data=<data>,
+    pool_type='<type>',
+<if(global_pool)>
+    global_pool=<global_pool>,
+<endif>
+<if(kernel_h)>
+    kernel=(<kernel_h>,<kernel_w>),
+<endif>
+<if(stride_h)>
+    stride=(<stride_h>,<stride_w>),
+<endif>
+<if(pad_h)>
+    pad=(<pad_h>,<pad_w>),
+<endif>
+    pooling_convention='full',
+    name='<name>')
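The pooling template hard-codes pooling_convention='full' because Caffe rounds the pooled output size up, whereas MXNet's default 'valid' convention rounds down. A sketch of the two size formulas follows. pooled_size is a hypothetical helper, and it ignores the extra clipping Caffe applies when padding is non-zero:

```python
import math

def pooled_size(in_size, kernel, stride, pad, convention='valid'):
    span = in_size + 2 * pad - kernel
    if convention == 'full':
        # Round up, as Caffe's pooling layer does
        return int(math.ceil(span / float(stride))) + 1
    # 'valid' rounds down
    return span // stride + 1

print(pooled_size(14, 3, 2, 0, 'valid'))  # 6
print(pooled_size(14, 3, 2, 0, 'full'))   # 7
```

Without 'full', a translated network can end up one row/column smaller than its Caffe original after such a layer, which would change the input shape of every later layer.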
diff --git a/tools/caffe_translator/src/main/resources/templates/power.st b/tools/caffe_translator/src/main/resources/templates/power.st
new file mode 100644
index 0000000..a512a67
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/power.st
@@ -0,0 +1 @@
+<var> = (<shift> + (<scale> * <data>)) ** <power>
diff --git a/tools/caffe_translator/src/main/resources/templates/runner.st b/tools/caffe_translator/src/main/resources/templates/runner.st
new file mode 100644
index 0000000..6df9671
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/runner.st
@@ -0,0 +1,57 @@
+ctx = <ctx>
+
+module = mx.mod.Module(symbol=<loss>, context=ctx, data_names=[<data_names;separator=", ">], label_names=[<label_names;separator=", ">])
+module.bind(data_shapes=<train_data_itr>.provide_data,
+            label_shapes=<train_data_itr>.provide_label)
+
+<init_params>
+
+<init_optimizer>
+
+epoch = 1
+batch_num = 1
+
+max_iter = <max_iter>
+snapshot = <snapshot>
+test_interval = <test_interval>
+test_iter = <test_iter>
+
+while batch_num \<= max_iter:
+    <train_data_itr>.reset()
+
+    for batch in <train_data_itr>:
+        module.forward(data_batch=batch, is_train=True)
+        module.backward()
+        module.update()
+
+        train_metrics.process(batch_num, module, batch.label)
+
+        if(batch_num % test_interval == 0):
+            logger.info("Iteration %d, Testing net" % batch_num)
+            test_metrics.score_and_print(module, <test_data_itr>, num_batch=test_iter)
+
+        if(batch_num % snapshot == 0):
+            # write snapshot
+            module.save_checkpoint(prefix="<snapshot_prefix>", epoch=batch_num, save_optimizer_states=True)
+
+        batch_num += 1
+
+        if batch_num > max_iter:
+            break
+
+<if(stepsize)>
+        stepsize = <stepsize>
+<endif>
+<if(gamma)>
+        gamma = <gamma>
+<endif>
+<if(power)>
+        power = <power>
+<endif>
+<lr_update>
+
+    epoch += 1
+
+
+logger.info("Training done. Saving model to <snapshot_prefix>")
+module.save_checkpoint(prefix="<snapshot_prefix>", epoch=batch_num, save_optimizer_states=True)
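The generated runner is driven by Caffe-style iteration counts rather than epochs. batch_num keeps increasing across passes over the training iterator, testing and snapshots fire on multiples of test_interval and snapshot, and training can stop in the middle of an epoch once max_iter is reached. The MXNet-free sketch below replays only that control flow, with a fixed number of batches standing in for one pass over the data iterator. simulate_runner is a hypothetical helper:

```python
def simulate_runner(batches_per_epoch, max_iter, test_interval, snapshot):
    # Returns the (batch_num, event) pairs the generated loop would trigger.
    # Assumes batches_per_epoch > 0.
    events = []
    batch_num = 1
    while batch_num <= max_iter:
        for _ in range(batches_per_epoch):  # one pass over the training iterator
            events.append((batch_num, 'train'))
            if batch_num % test_interval == 0:
                events.append((batch_num, 'test'))
            if batch_num % snapshot == 0:
                events.append((batch_num, 'snapshot'))
            batch_num += 1
            if batch_num > max_iter:
                break
    return events

for event in simulate_runner(4, 10, 5, 6):
    print(event)
```

With 4 batches per epoch and max_iter=10, the third pass stops after two batches. Testing runs at iterations 5 and 10, and a snapshot is written at iteration 6.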
diff --git a/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st b/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st
new file mode 100644
index 0000000..bc63891
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st
@@ -0,0 +1,3 @@
+<var> = mx.sym.SoftmaxOutput(data=<data>, label=<label>, name='<name>')
+<var>_metric = mx.metric.CrossEntropy(output_names=['<name>_output'], label_names=['<label_name>'], name='<name>/metric')
+train_metrics.add(<var>_metric)
diff --git a/tools/caffe_translator/src/main/resources/templates/symbols.stg b/tools/caffe_translator/src/main/resources/templates/symbols.stg
new file mode 100644
index 0000000..fda9125
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/symbols.stg
@@ -0,0 +1,7 @@
+CaffePluginIntLayer(var, tops, num_data, num_weight, num_out, data, prototxt, name) ::= "<var> = mx.symbol.CaffeOp(<if(data)><data>, <endif><if(num_data)>num_data=<num_data>, <endif><if(num_out)>num_out=<num_out>, <endif><if(num_weight)>num_weight=<num_weight>, <endif>prototxt='<prototxt>', name='<name>')
+<if(tops)><tops:{top|<top_assign(top, var, i0)>};separator=\"\n\"> <endif>"
+
+CaffePluginLossLayer(var, tops, num_data, data, prototxt, name) ::= "<var> = mx.symbol.CaffeLoss(<data><if(num_data)>, num_data=<num_data><endif>, prototxt='<prototxt>', name='<name>')
+<if(tops)><tops:{top|<top_assign(top, var, i0)>};separator=\"\n\"> <endif>"
+
+top_assign(top, var, index) ::= "<top> = <var>[<index>]"
diff --git a/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st b/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st
new file mode 100644
index 0000000..de93ee9
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st
@@ -0,0 +1,2 @@
+<var> = mx.metric.TopKAccuracy(top_k=<k>, output_names=['<output_name>'], label_names=['<label_name>'], name='<name>')
+test_metrics.add(<var>)
diff --git a/tools/caffe_translator/src/main/resources/templates/var.st b/tools/caffe_translator/src/main/resources/templates/var.st
new file mode 100644
index 0000000..e850b689
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/var.st
@@ -0,0 +1 @@
+<var> = mx.sym.Variable('<name>'<if(lr_mult)>, lr_mult=<lr_mult><endif><if(wd_mult)>, wd_mult=<wd_mult><endif><if(init)>, init=<init><endif><if(shape)>, shape=(<shape;separator=", ">)<endif>)
