Posted to commits@mxnet.apache.org by in...@apache.org on 2017/12/09 00:28:25 UTC
[incubator-mxnet] branch master updated: Minor changes to Caffe Translator (#8939)
This is an automated email from the ASF dual-hosted git repository.
indhub pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9d4bb9c Minor changes to Caffe Translator (#8939)
9d4bb9c is described below
commit 9d4bb9c53c00d3b8f8ddd1e9e92ac8cbeb885111
Author: Indhu Bharathi <in...@gmail.com>
AuthorDate: Fri Dec 8 16:28:17 2017 -0800
Minor changes to Caffe Translator (#8939)
* - Add license to string template files.
- Add license to gradlew
- Some bug fixes and refactoring for optimizer generation.
- Language change in comment that goes into generated code.
- Don't generate CaffeLoss layer for Accuracy layer. It is now being translated to MXNet Accuracy metrics.
- Minor bug fix in searching for the correct optimizer template.
- Bump version up to 0.9.2
* Add license for Optimizer.java
* Code cleanup.
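
Illustrative note on the Accuracy change listed above. This is a minimal sketch of the fallback decision, not the code from this commit: the class name, the strategy labels and the set of layer types that have a generator are all hypothetical stand-ins. It only shows the order of the checks: a layer with its own generator is translated natively, an Accuracy layer is skipped so it can be turned into an MXNet metric later, any other type ending in "loss" is wrapped in CaffeLoss, and everything else is wrapped in CaffeOp.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;

public class LayerFallbackSketch {
    // Hypothetical: layer types assumed to have a dedicated generator.
    private static final Set<String> HAS_GENERATOR =
            new HashSet<>(Arrays.asList("convolution", "pooling", "relu"));

    /** Returns a label for how a Caffe layer type would be handled. */
    public static String chooseStrategy(String layerType) {
        String type = layerType.toLowerCase(Locale.ROOT);
        if (HAS_GENERATOR.contains(type)) {
            return "native";     // a dedicated generator exists
        }
        if (type.equals("accuracy")) {
            return "skip";       // emit no symbol; handled later as a metric
        }
        if (type.endsWith("loss")) {
            return "CaffeLoss";  // unknown loss layer: wrap in CaffeLoss
        }
        return "CaffeOp";        // any other unknown layer: wrap in CaffeOp
    }

    public static void main(String[] args) {
        for (String t : Arrays.asList("Convolution", "Accuracy", "ContrastiveLoss", "Slice")) {
            System.out.println(t + " -> " + chooseStrategy(t));
        }
    }
}
```

A skipped layer still has to advance the layer index by one, which is why the Converter.java change below adds an explicit else branch.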
---
tools/caffe_translator/build.gradle | 2 +-
tools/caffe_translator/gradlew | 17 +++++
.../java/io/mxnet/caffetranslator/Converter.java | 75 +++++++---------------
.../java/io/mxnet/caffetranslator/Optimizer.java | 48 ++++++++++++++
.../main/java/io/mxnet/caffetranslator/Solver.java | 52 ++++++++++++++-
.../src/main/resources/templates/accuracy.st | 18 ++++++
.../src/main/resources/templates/activation.st | 18 ++++++
.../src/main/resources/templates/add.st | 18 ++++++
.../src/main/resources/templates/batchnorm.st | 18 ++++++
.../src/main/resources/templates/concat.st | 18 ++++++
.../src/main/resources/templates/convolution.st | 18 ++++++
.../src/main/resources/templates/deconvolution.st | 18 ++++++
.../src/main/resources/templates/dropout.st | 18 ++++++
.../src/main/resources/templates/fc.st | 18 ++++++
.../src/main/resources/templates/flatten.st | 18 ++++++
.../src/main/resources/templates/group.st | 18 ++++++
.../src/main/resources/templates/imports.st | 18 ++++++
.../src/main/resources/templates/init_params.st | 18 ++++++
.../src/main/resources/templates/iterator.st | 18 ++++++
.../src/main/resources/templates/logging.st | 18 ++++++
.../src/main/resources/templates/lrn.st | 18 ++++++
.../src/main/resources/templates/lrpolicy_exp.st | 18 ++++++
.../src/main/resources/templates/lrpolicy_inv.st | 18 ++++++
.../main/resources/templates/lrpolicy_multistep.st | 18 ++++++
.../src/main/resources/templates/lrpolicy_poly.st | 18 ++++++
.../main/resources/templates/lrpolicy_sigmoid.st | 18 ++++++
.../src/main/resources/templates/lrpolicy_step.st | 18 ++++++
.../src/main/resources/templates/maxium.st | 18 ++++++
.../main/resources/templates/metrics_classes.st | 27 ++++++--
.../src/main/resources/templates/mul.st | 18 ++++++
.../src/main/resources/templates/opt_adadelta.st | 32 +++++++++
.../src/main/resources/templates/opt_adagrad.st | 28 ++++++++
.../src/main/resources/templates/opt_adam.st | 36 +++++++++++
.../src/main/resources/templates/opt_default.st | 15 -----
.../src/main/resources/templates/opt_nesterov.st | 28 ++++++++
.../src/main/resources/templates/opt_rmsprop.st | 32 +++++++++
.../src/main/resources/templates/opt_sgd.st | 36 ++++++++---
.../src/main/resources/templates/opt_vars.st | 24 +++++++
.../main/resources/templates/param_initializer.st | 18 ++++++
.../src/main/resources/templates/params_loader.st | 18 ++++++
.../src/main/resources/templates/permute.st | 18 ++++++
.../src/main/resources/templates/pooling.st | 18 ++++++
.../src/main/resources/templates/power.st | 18 ++++++
.../src/main/resources/templates/runner.st | 18 ++++++
.../src/main/resources/templates/softmaxoutput.st | 18 ++++++
.../src/main/resources/templates/symbols.stg | 18 ++++++
.../src/main/resources/templates/top_k_accuracy.st | 18 ++++++
.../src/main/resources/templates/var.st | 18 ++++++
48 files changed, 979 insertions(+), 85 deletions(-)
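
Illustrative note on the Solver.java change. The sketch below shows, under simplifying assumptions, the general technique of copying parsed key/value properties into same-named fields through reflection and then applying a default. The class name, the reduced field list, the fromPairs helper and the hand-written getters are hypothetical; the real class declares many more fields and uses Lombok's @Getter.

```java
import java.lang.reflect.Field;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SolverFieldsSketch {
    // Field names deliberately match solver prototxt keys so they can be found by name.
    private String base_lr, momentum, type;

    /** Copies the first value of each known key into the field of the same name. */
    public void setFields(Map<String, List<String>> properties) {
        for (Map.Entry<String, List<String>> entry : properties.entrySet()) {
            try {
                Field field = SolverFieldsSketch.class.getDeclaredField(entry.getKey());
                field.set(this, entry.getValue().get(0));
            } catch (NoSuchFieldException e) {
                // Key has no matching field: ignore it.
            } catch (IllegalAccessException e) {
                // Not expected for fields declared in this class; surface it if it happens.
                throw new IllegalStateException(e);
            }
        }
        if (type == null) {
            type = "SGD"; // default used when the prototxt omits the key
        }
    }

    /** Hypothetical helper: builds the property map from alternating keys and values. */
    public static SolverFieldsSketch fromPairs(String... keyValues) {
        Map<String, List<String>> props = new HashMap<>();
        for (int i = 0; i + 1 < keyValues.length; i += 2) {
            props.put(keyValues[i], Collections.singletonList(keyValues[i + 1]));
        }
        SolverFieldsSketch solver = new SolverFieldsSketch();
        solver.setFields(props);
        return solver;
    }

    public String getBaseLr() { return base_lr; }
    public String getMomentum() { return momentum; }
    public String getType() { return type; }

    public static void main(String[] args) {
        SolverFieldsSketch s = fromPairs("base_lr", "0.01", "not_a_field", "x");
        System.out.println(s.getBaseLr() + " " + s.getType());
    }
}
```

Unknown keys are dropped silently in this sketch, so a misspelled prototxt key falls through to the default rather than raising an error.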
diff --git a/tools/caffe_translator/build.gradle b/tools/caffe_translator/build.gradle
index 4206767..da5e900 100644
--- a/tools/caffe_translator/build.gradle
+++ b/tools/caffe_translator/build.gradle
@@ -10,7 +10,7 @@ apply plugin: 'maven'
apply plugin: 'signing'
group 'org.caffetranslator'
-version '0.9.1'
+version '0.9.2'
def isReleaseBuild
def repositoryUrl
diff --git a/tools/caffe_translator/gradlew b/tools/caffe_translator/gradlew
index cccdd3d..07cc915 100755
--- a/tools/caffe_translator/gradlew
+++ b/tools/caffe_translator/gradlew
@@ -1,5 +1,22 @@
#!/usr/bin/env sh
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
##############################################################################
##
## Gradle start up script for UN*X
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java
index 90ed9d2..96d6fec 100644
--- a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Converter.java
@@ -154,22 +154,33 @@ public class Converter {
Layer layer = layers.get(layerIndex);
SymbolGenerator generator = generators.getGenerator(layer.getType());
- // If the translator cannot translate this layer to an MXNet layer,
- // use CaffeOp or CaffeLoss instead.
+ // Handle layers for which there is no Generator
if (generator == null) {
- if (layer.getType().toLowerCase().endsWith("loss")) {
+ if (layer.getType().equalsIgnoreCase("Accuracy")) {
+ // We handle accuracy layers at a later stage. Do nothing for now.
+ } else if (layer.getType().toLowerCase().endsWith("loss")) {
+ // This is a loss layer we don't have a generator for. Wrap it in CaffeLoss.
generator = generators.getGenerator("CaffePluginLossLayer");
} else {
+ // This is a layer we don't have a generator for. Wrap it in CaffeOp.
generator = generators.getGenerator("PluginIntLayerGenerator");
}
}
- GeneratorOutput out = generator.generate(layer, mlModel);
- String segment = out.code;
- code.append(segment);
- code.append(NL);
-
- layerIndex += out.numLayersTranslated;
+ if (generator != null) { // If we have a generator
+ // Generate code
+ GeneratorOutput out = generator.generate(layer, mlModel);
+ String segment = out.code;
+ code.append(segment);
+ code.append(NL);
+
+ // Update layerIndex depending on how many layers we ended up translating
+ layerIndex += out.numLayersTranslated;
+ } else { // If we don't have a generator
+ // We've decided to skip this layer. Generate no code. Just increment layerIndex
+ // by 1 and move on to the next layer.
+ layerIndex++;
+ }
}
String loss = getLoss(mlModel, code);
@@ -304,50 +315,8 @@ public class Converter {
}
private String generateOptimizer() {
- String caffeOptimizer = solver.getProperty("type", "sgd").toLowerCase();
- ST st;
-
- String lr = solver.getProperty("base_lr");
- String momentum = solver.getProperty("momentum", "0.9");
- String wd = solver.getProperty("weight_decay", "0.0005");
-
- switch (caffeOptimizer) {
- case "adadelta":
- st = gh.getTemplate("opt_default");
- st.add("opt_name", "AdaDelta");
- st.add("epsilon", solver.getProperty("delta"));
- break;
- case "adagrad":
- st = gh.getTemplate("opt_default");
- st.add("opt_name", "AdaGrad");
- break;
- case "adam":
- st = gh.getTemplate("opt_default");
- st.add("opt_name", "Adam");
- break;
- case "nesterov":
- st = gh.getTemplate("opt_sgd");
- st.add("opt_name", "NAG");
- st.add("momentum", momentum);
- break;
- case "rmsprop":
- st = gh.getTemplate("opt_default");
- st.add("opt_name", "RMSProp");
- break;
- default:
- if (!caffeOptimizer.equals("sgd")) {
- System.err.println("Unknown optimizer. Will use SGD instead.");
- }
-
- st = gh.getTemplate("opt_sgd");
- st.add("opt_name", "SGD");
- st.add("momentum", momentum);
- break;
- }
- st.add("lr", lr);
- st.add("wd", wd);
-
- return st.render();
+ Optimizer optimizer = new Optimizer(solver);
+ return optimizer.generateInitCode();
}
private String generateInitializer() {
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Optimizer.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Optimizer.java
new file mode 100644
index 0000000..da24942
--- /dev/null
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Optimizer.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file Optimizer.java
+ * \brief Generates optimizer from solver prototxt
+ */
+
+package io.mxnet.caffetranslator;
+
+import org.stringtemplate.v4.ST;
+
+public class Optimizer {
+ private final GenerationHelper gh;
+ private final Solver solver;
+
+ public Optimizer(Solver solver) {
+ this.gh = new GenerationHelper();
+ this.solver = solver;
+ }
+
+ public String generateInitCode() {
+ ST st = gh.getTemplate("opt_" + solver.getType().toLowerCase());
+ if (st == null) {
+ System.err.println(String.format("Unknown optimizer type (%s). Using SGD instead.", solver.getType()));
+ st = gh.getTemplate("opt_sgd");
+ }
+
+ st.add("solver", solver);
+ return st.render();
+ }
+}
diff --git a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java
index ec4c812..9693771 100644
--- a/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java
+++ b/tools/caffe_translator/src/main/java/io/mxnet/caffetranslator/Solver.java
@@ -24,6 +24,7 @@
package io.mxnet.caffetranslator;
+import lombok.Getter;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
@@ -31,6 +32,7 @@ import org.antlr.v4.runtime.CommonTokenStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
+import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
@@ -38,9 +40,18 @@ import java.util.Map;
public class Solver {
+ private final String solverPath;
private boolean parseDone;
private Map<String, List<String>> properties;
- private final String solverPath;
+ /**
+ * Fields corresponding to keys that can be present in the solver prototxt. 'setFields' sets these
+ * using reflection after parsing the solver prototxt. A solver object is passed to string templates
+ * and the templates read these fields.
+ */
+ @Getter
+ private String base_lr, momentum, weight_decay, lr_policy, gamma, stepsize, stepvalue, max_iter,
+ solver_mode, snapshot, snapshot_prefix, test_iter, test_interval, display, type, delta,
+ momentum2, rms_decay, solver_type;
public Solver(String solverPath) {
this.solverPath = solverPath;
@@ -67,10 +78,49 @@ public class Solver {
properties = solverListener.getProperties();
+ setFields(properties);
+
parseDone = true;
return true;
}
+ private void setFields(Map<String, List<String>> properties) {
+ Class<?> cls = getClass();
+
+ for (Map.Entry<String, List<String>> entry : properties.entrySet()) {
+ String key = entry.getKey();
+ try {
+ Field field = cls.getDeclaredField(key);
+ field.set(this, entry.getValue().get(0));
+ } catch (NoSuchFieldException e) {
+ // Just ignore
+ } catch (IllegalAccessException e) {
+ /**
+ * This shouldn't happen. If it does happen because we overlooked something, print
+ * it in the console so we can investigate it.
+ */
+ e.printStackTrace();
+ }
+ }
+
+ setDefaults();
+ }
+
+ private void setDefaults() {
+ if (type == null) {
+ type = "SGD";
+ }
+ if (delta == null) {
+ delta = "1e-8";
+ }
+ if (momentum2 == null) {
+ momentum2 = "0.999";
+ }
+ if (rms_decay == null) {
+ rms_decay = "0.99";
+ }
+ }
+
public String getProperty(String key) {
List<String> list = getProperties(key);
if (list == null) {
diff --git a/tools/caffe_translator/src/main/resources/templates/accuracy.st b/tools/caffe_translator/src/main/resources/templates/accuracy.st
index f741def..cbe15f6 100644
--- a/tools/caffe_translator/src/main/resources/templates/accuracy.st
+++ b/tools/caffe_translator/src/main/resources/templates/accuracy.st
@@ -1,2 +1,20 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.metric.Accuracy(output_names=['<output_name>'], label_names=['<label_name>'], name='<name>')
test_metrics.add(<var>)
diff --git a/tools/caffe_translator/src/main/resources/templates/activation.st b/tools/caffe_translator/src/main/resources/templates/activation.st
index 5a9c37b..042c2e3 100644
--- a/tools/caffe_translator/src/main/resources/templates/activation.st
+++ b/tools/caffe_translator/src/main/resources/templates/activation.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.symbol.Activation(data=<data>, act_type='<type>', name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/add.st b/tools/caffe_translator/src/main/resources/templates/add.st
index ca9428f..738ac3e 100644
--- a/tools/caffe_translator/src/main/resources/templates/add.st
+++ b/tools/caffe_translator/src/main/resources/templates/add.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = <data1> + <data2>
diff --git a/tools/caffe_translator/src/main/resources/templates/batchnorm.st b/tools/caffe_translator/src/main/resources/templates/batchnorm.st
index c043c70..7f2326d 100644
--- a/tools/caffe_translator/src/main/resources/templates/batchnorm.st
+++ b/tools/caffe_translator/src/main/resources/templates/batchnorm.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<if(fix_beta)>
<var>_beta = mx.sym.BlockGrad(mx.sym.Variable("<name>_beta", init=mx.init.Constant(0)))
<endif>
diff --git a/tools/caffe_translator/src/main/resources/templates/concat.st b/tools/caffe_translator/src/main/resources/templates/concat.st
index 75ffa3c..3f33275 100644
--- a/tools/caffe_translator/src/main/resources/templates/concat.st
+++ b/tools/caffe_translator/src/main/resources/templates/concat.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.concat(<data;separator=", "><if(dim)>, dim=<dim><endif>, name='<name>');
diff --git a/tools/caffe_translator/src/main/resources/templates/convolution.st b/tools/caffe_translator/src/main/resources/templates/convolution.st
index c4bdd51..c167217 100644
--- a/tools/caffe_translator/src/main/resources/templates/convolution.st
+++ b/tools/caffe_translator/src/main/resources/templates/convolution.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.Convolution(data=<data>,
<if(weight)>weight=<weight>,<endif>
<if(bias)>bias=<bias>,<endif>
diff --git a/tools/caffe_translator/src/main/resources/templates/deconvolution.st b/tools/caffe_translator/src/main/resources/templates/deconvolution.st
index 5b63f56..67483b9 100644
--- a/tools/caffe_translator/src/main/resources/templates/deconvolution.st
+++ b/tools/caffe_translator/src/main/resources/templates/deconvolution.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.Deconvolution(data=<data>,
<if(use_weight)>weight=weight,<endif>
<if(use_bias)>bias=bias,<endif>
diff --git a/tools/caffe_translator/src/main/resources/templates/dropout.st b/tools/caffe_translator/src/main/resources/templates/dropout.st
index 9791c09..ed28dc7 100644
--- a/tools/caffe_translator/src/main/resources/templates/dropout.st
+++ b/tools/caffe_translator/src/main/resources/templates/dropout.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.Dropout(data=<data>, p=<prob>, name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/fc.st b/tools/caffe_translator/src/main/resources/templates/fc.st
index 22365b3..353b424 100644
--- a/tools/caffe_translator/src/main/resources/templates/fc.st
+++ b/tools/caffe_translator/src/main/resources/templates/fc.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.symbol.FullyConnected(data=<data>, <if(weight)>weight=<weight>, <endif><if(bias)>bias=<bias>, <endif>num_hidden=<num>, <if(no_bias)>no_bias=True, <endif>name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/flatten.st b/tools/caffe_translator/src/main/resources/templates/flatten.st
index 8434335..2ee6ffa 100644
--- a/tools/caffe_translator/src/main/resources/templates/flatten.st
+++ b/tools/caffe_translator/src/main/resources/templates/flatten.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.flatten(data=<data>, name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/group.st b/tools/caffe_translator/src/main/resources/templates/group.st
index 33e312f..9cadf65 100644
--- a/tools/caffe_translator/src/main/resources/templates/group.st
+++ b/tools/caffe_translator/src/main/resources/templates/group.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.Group([<symbols;separator=", ">]);
diff --git a/tools/caffe_translator/src/main/resources/templates/imports.st b/tools/caffe_translator/src/main/resources/templates/imports.st
index b37bd33..da03a64 100644
--- a/tools/caffe_translator/src/main/resources/templates/imports.st
+++ b/tools/caffe_translator/src/main/resources/templates/imports.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
from __future__ import division
import copy
import logging
diff --git a/tools/caffe_translator/src/main/resources/templates/init_params.st b/tools/caffe_translator/src/main/resources/templates/init_params.st
index 3a277b6..7c8d7b0 100644
--- a/tools/caffe_translator/src/main/resources/templates/init_params.st
+++ b/tools/caffe_translator/src/main/resources/templates/init_params.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<if(params_file)>
arg_params, aux_params = load_params('<params_file>')
module.init_params(initializer=mx.init.Xavier(), arg_params=arg_params, aux_params=aux_params,
diff --git a/tools/caffe_translator/src/main/resources/templates/iterator.st b/tools/caffe_translator/src/main/resources/templates/iterator.st
index 5bc2a9d..d608979 100644
--- a/tools/caffe_translator/src/main/resources/templates/iterator.st
+++ b/tools/caffe_translator/src/main/resources/templates/iterator.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<iter_name> = mx.io.CaffeDataIter(
prototxt =
<prototxt>,
diff --git a/tools/caffe_translator/src/main/resources/templates/logging.st b/tools/caffe_translator/src/main/resources/templates/logging.st
index 73785e5..cc94872 100644
--- a/tools/caffe_translator/src/main/resources/templates/logging.st
+++ b/tools/caffe_translator/src/main/resources/templates/logging.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
def get_logger(name):
formatter = logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
datefmt='%Y-%m-%d %H:%M:%S')
diff --git a/tools/caffe_translator/src/main/resources/templates/lrn.st b/tools/caffe_translator/src/main/resources/templates/lrn.st
index ec003c1..b679898 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrn.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrn.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.LRN(data=<data>, alpha=<alpha>, beta=<beta>, knorm=<knorm>, nsize=<nsize>, name=<name>)
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st
index 43afca2..03daae3 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_exp.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
lr = optimizer_params['learning_rate']
lr *= gamma
optimizer_params['learning_rate'] = lr
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st
index 5da8aa6..e62c2d3 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_inv.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
lr = optimizer_params['learning_rate']
lr = base_lr * math.pow((1 + gamma * batch_num), -power)
optimizer_params['learning_rate'] = lr
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st
index fe09301..0761908 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_multistep.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
lr_update_steps = [<steps;separator=", ">]
if(batch_num in lr_update_steps):
lr = optimizer_params['learning_rate']
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st
index e43fd78..d62c64b 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_poly.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
lr = optimizer_params['learning_rate']
lr = math.pow(base_lr * (1 - batch_num/max_iter), power)
optimizer_params['learning_rate'] = lr
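Reviewer note on the poly policy: Caffe's solver documentation describes poly as base_lr * (1 - iter/max_iter) ^ power. Below is a minimal sketch of that documented form, with an explicit float conversion so the ratio does not truncate under Python 2 integer division. `poly_lr` is a hypothetical helper and not part of the translator; it may be worth comparing its grouping against the template line above.

```python
def poly_lr(base_lr, batch_num, max_iter, power):
    """Polynomial decay as documented for Caffe's 'poly' lr_policy.

    Hypothetical helper for illustration only; the name and signature
    are assumptions, not code from the translator.
    """
    # float() keeps the ratio fractional even under Python 2 division.
    remaining = 1.0 - float(batch_num) / max_iter
    return base_lr * remaining ** power


if __name__ == "__main__":
    for step in (0, 50, 100):
        print(step, poly_lr(0.1, step, 100, 2))
```

Under this form the rate starts at base_lr and reaches zero at max_iter.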
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st
index 33ba055..f44ab5a 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_sigmoid.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
lr = optimizer_params['learning_rate']
lr = base_lr * ( 1/(1 + math.exp(-gamma * (batch_num - stepsize))))
optimizer_params['learning_rate'] = lr
diff --git a/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st b/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st
index 04468ae..1f3d975 100644
--- a/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st
+++ b/tools/caffe_translator/src/main/resources/templates/lrpolicy_step.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
if(batch_num % stepsize == 0):
lr = optimizer_params['learning_rate']
lr *= gamma
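Reviewer note on the step policy: multiplying lr by gamma every stepsize batches should match Caffe's documented closed form base_lr * gamma ^ floor(iter / stepsize). The sketch below checks that equivalence. Both function names are hypothetical, and it assumes batch_num starts at 1; if the generated loop counts from 0, the first decay would fire immediately because 0 % stepsize == 0.

```python
def step_lr_closed_form(base_lr, gamma, stepsize, batch_num):
    """Closed form of Caffe's 'step' policy (hypothetical helper)."""
    return base_lr * gamma ** (batch_num // stepsize)


def step_lr_iterative(base_lr, gamma, stepsize, num_batches):
    """Replays the template's in-place update, assuming batches count from 1."""
    lr = base_lr
    for batch_num in range(1, num_batches + 1):
        if batch_num % stepsize == 0:
            lr *= gamma
    return lr


if __name__ == "__main__":
    print(step_lr_iterative(0.1, 0.5, 10, 25))
    print(step_lr_closed_form(0.1, 0.5, 10, 25))
```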
diff --git a/tools/caffe_translator/src/main/resources/templates/maxium.st b/tools/caffe_translator/src/main/resources/templates/maxium.st
index d9431dd..9b18246 100644
--- a/tools/caffe_translator/src/main/resources/templates/maxium.st
+++ b/tools/caffe_translator/src/main/resources/templates/maxium.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.maximum(<data1>, <data2>)
diff --git a/tools/caffe_translator/src/main/resources/templates/metrics_classes.st b/tools/caffe_translator/src/main/resources/templates/metrics_classes.st
index e8323fb..e586616 100644
--- a/tools/caffe_translator/src/main/resources/templates/metrics_classes.st
+++ b/tools/caffe_translator/src/main/resources/templates/metrics_classes.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
class TrainMetrics():
metric_map = {}
@@ -16,17 +34,16 @@ class TrainMetrics():
self.update_metrics(module, label, reset=True)
self.print_metrics(batch_num)
else:
- # If I'll have to print metrics 'average_loss' iterations from now,
- # append a metric so I can start updating that.
+ # Metrics must be printed 'average_loss' iterations from now.
+ # Append a metric which will get updated starting now.
if((batch_num + self.average_loss) % self.display == 0):
self.append_one()
- # If I'm less than 'average_loss' iteration away from a display step,
- # update the metrics.
+ # Less than 'average_loss' iterations away from a display step. Update metrics.
if((batch_num + self.average_loss) % self.display \< self.average_loss):
self.update_metrics(module, label)
- # If I'm at a display step, print the metrics.
+ # At display step. Print metrics.
if(batch_num % self.display == 0):
self.print_metrics(batch_num, remove_heads=True)
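Reviewer note: the three checks in the else branch above can be traced with a small standalone function. This is only a sketch. `metric_actions` and its string labels are hypothetical stand-ins for `append_one`, `update_metrics` and `print_metrics`, and the branch above the `else` is not modeled.

```python
def metric_actions(batch_num, display, average_loss):
    """Return which metric actions the else branch would take for a batch."""
    actions = []
    # A display step is exactly 'average_loss' batches away: start a new metric.
    if (batch_num + average_loss) % display == 0:
        actions.append("append")
    # Inside the averaging window that precedes a display step: accumulate.
    if (batch_num + average_loss) % display < average_loss:
        actions.append("update")
    # At a display step: report.
    if batch_num % display == 0:
        actions.append("print")
    return actions


if __name__ == "__main__":
    for batch in (50, 90, 95, 100):
        print(batch, metric_actions(batch, display=100, average_loss=10))
```

With display=100 and average_loss=10, batches 90 through 99 update the metric and batch 100 prints it.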
diff --git a/tools/caffe_translator/src/main/resources/templates/mul.st b/tools/caffe_translator/src/main/resources/templates/mul.st
index 411a407..59c4837 100644
--- a/tools/caffe_translator/src/main/resources/templates/mul.st
+++ b/tools/caffe_translator/src/main/resources/templates/mul.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = <data1> * (<data2>)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_adadelta.st b/tools/caffe_translator/src/main/resources/templates/opt_adadelta.st
new file mode 100644
index 0000000..cfd465b
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_adadelta.st
@@ -0,0 +1,32 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<opt_vars(solver)>
+<if(solver.momentum)>
+rho = <solver.momentum>
+<endif>
+<if(solver.delta)>
+epsilon = <solver.delta>
+<endif>
+
+optimizer_params={'learning_rate':base_lr<\\>
+<if(solver.wd)>, 'wd':wd<endif><\\>
+<if(solver.momentum)>, 'rho':rho<endif><\\>
+<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\>
+
+module.init_optimizer(optimizer='AdaDelta', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_adagrad.st b/tools/caffe_translator/src/main/resources/templates/opt_adagrad.st
new file mode 100644
index 0000000..527cedf
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_adagrad.st
@@ -0,0 +1,28 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<opt_vars(solver)>
+<if(solver.delta)>
+epsilon = <solver.delta>
+<endif>
+
+optimizer_params={'learning_rate':base_lr<\\>
+<if(solver.wd)>, 'wd':wd<endif><\\>
+<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\>
+
+module.init_optimizer(optimizer='AdaGrad', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_adam.st b/tools/caffe_translator/src/main/resources/templates/opt_adam.st
new file mode 100644
index 0000000..b0a8ca3
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_adam.st
@@ -0,0 +1,36 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<opt_vars(solver)>
+<if(solver.momentum)>
+beta1 = <solver.momentum>
+<endif>
+<if(solver.momentum2)>
+beta2 = <solver.momentum2>
+<endif>
+<if(solver.delta)>
+epsilon = <solver.delta>
+<endif>
+
+optimizer_params={'learning_rate':base_lr<\\>
+<if(solver.wd)>, 'wd':wd<endif><\\>
+<if(solver.momentum)>, 'beta1':beta1<endif><\\>
+<if(solver.momentum2)>, 'beta2':beta2<endif><\\>
+<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\>
+
+module.init_optimizer(optimizer='Adam', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_default.st b/tools/caffe_translator/src/main/resources/templates/opt_default.st
deleted file mode 100644
index e5a72ac..0000000
--- a/tools/caffe_translator/src/main/resources/templates/opt_default.st
+++ /dev/null
@@ -1,15 +0,0 @@
-<if(lr)>
-base_lr = <lr>
-<endif>
-<if(momentum)>
-momentum = <momentum>
-<endif>
-<if(wd)>
-wd = <wd>
-<endif>
-<if(epsilon)>
-epsilon = <epsilon>
-<endif>
-
-optimizer_params={'learning_rate':base_lr <if(momentum)>, 'momentum':momentum<endif><if(wd)>, 'wd':wd<endif><if(epsilon)>, 'epsilon':epsilon<endif>}
-module.init_optimizer(optimizer='<opt_name>', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_nesterov.st b/tools/caffe_translator/src/main/resources/templates/opt_nesterov.st
new file mode 100644
index 0000000..6262d48
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_nesterov.st
@@ -0,0 +1,28 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<opt_vars(solver)>
+<if(solver.momentum)>
+momentum = <solver.momentum>
+<endif>
+
+optimizer_params={'learning_rate':base_lr<\\>
+<if(solver.wd)>, 'wd':wd<endif><\\>
+<if(solver.momentum)>, 'momentum':momentum<endif>}<\\>
+
+module.init_optimizer(optimizer='NAG', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_rmsprop.st b/tools/caffe_translator/src/main/resources/templates/opt_rmsprop.st
new file mode 100644
index 0000000..6baec42
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_rmsprop.st
@@ -0,0 +1,32 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<opt_vars(solver)>
+<if(solver.rms_decay)>
+gamma1 = <solver.rms_decay>
+<endif>
+<if(solver.delta)>
+epsilon = <solver.delta>
+<endif>
+
+optimizer_params={'learning_rate':base_lr<\\>
+<if(solver.wd)>, 'wd':wd<endif><\\>
+<if(solver.rms_decay)>, 'gamma1':gamma1<endif><\\>
+<if(solver.delta)>, 'epsilon':epsilon<endif>}<\\>
+
+module.init_optimizer(optimizer='RMSProp', optimizer_params=optimizer_params)
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_sgd.st b/tools/caffe_translator/src/main/resources/templates/opt_sgd.st
index 8a24e05..aa547a6 100644
--- a/tools/caffe_translator/src/main/resources/templates/opt_sgd.st
+++ b/tools/caffe_translator/src/main/resources/templates/opt_sgd.st
@@ -1,12 +1,28 @@
-<if(lr)>
-base_lr = <lr>
-<endif>
-<if(momentum)>
-momentum = <momentum>
-<endif>
-<if(wd)>
-wd = <wd>
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<opt_vars(solver)>
+<if(solver.momentum)>
+momentum = <solver.momentum>
<endif>
-optimizer_params={'learning_rate':base_lr <if(momentum)>, 'momentum':momentum<endif><if(wd)>, 'wd':wd<endif>}
-module.init_optimizer(optimizer='<opt_name>', optimizer_params=optimizer_params)
+optimizer_params={'learning_rate':base_lr<\\>
+<if(solver.wd)>, 'wd':wd<endif><\\>
+<if(solver.momentum)>, 'momentum':momentum<endif>}<\\>
+
+module.init_optimizer(optimizer='SGD', optimizer_params=optimizer_params)
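Reviewer note: the per-optimizer templates in this commit each rename a few Caffe solver fields to MXNet optimizer arguments. The sketch below collects those renames in one table, read off the opt_*.st files in this diff. `PARAM_MAP` and `build_optimizer_params` are hypothetical names, the solver is modeled as a plain dict, and a field counts as set when it is not None, which only approximates the template's `<if(...)>` test.

```python
# Caffe solver field -> MXNet optimizer argument, per optimizer template.
PARAM_MAP = {
    "SGD": {"momentum": "momentum"},
    "NAG": {"momentum": "momentum"},
    "AdaGrad": {"delta": "epsilon"},
    "AdaDelta": {"momentum": "rho", "delta": "epsilon"},
    "Adam": {"momentum": "beta1", "momentum2": "beta2", "delta": "epsilon"},
    "RMSProp": {"rms_decay": "gamma1", "delta": "epsilon"},
}


def build_optimizer_params(optimizer, solver):
    """Build the optimizer_params dict the rendered template would produce."""
    params = {"learning_rate": solver["base_lr"]}
    # 'wd' is emitted by every template when the solver defines it.
    if solver.get("wd") is not None:
        params["wd"] = solver["wd"]
    for caffe_key, mx_key in PARAM_MAP[optimizer].items():
        if solver.get(caffe_key) is not None:
            params[mx_key] = solver[caffe_key]
    return params


if __name__ == "__main__":
    solver = {"base_lr": 0.001, "momentum": 0.9, "momentum2": 0.999, "delta": 1e-8}
    print(build_optimizer_params("Adam", solver))
```

Splitting the templates per optimizer keeps each rename explicit, which the single removed opt_default.st could not express.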
diff --git a/tools/caffe_translator/src/main/resources/templates/opt_vars.st b/tools/caffe_translator/src/main/resources/templates/opt_vars.st
new file mode 100644
index 0000000..19b2f4c
--- /dev/null
+++ b/tools/caffe_translator/src/main/resources/templates/opt_vars.st
@@ -0,0 +1,24 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
+<if(solver.base_lr)>
+base_lr = <solver.base_lr>
+<endif>
+<if(solver.wd)>
+wd = <solver.wd>
+<endif>
\ No newline at end of file
diff --git a/tools/caffe_translator/src/main/resources/templates/param_initializer.st b/tools/caffe_translator/src/main/resources/templates/param_initializer.st
index b496fc3..abad5da 100644
--- a/tools/caffe_translator/src/main/resources/templates/param_initializer.st
+++ b/tools/caffe_translator/src/main/resources/templates/param_initializer.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
class ParamInitializer():
lst_patterns = []
lst_initializers = []
diff --git a/tools/caffe_translator/src/main/resources/templates/params_loader.st b/tools/caffe_translator/src/main/resources/templates/params_loader.st
index 22efec4..c124c98 100644
--- a/tools/caffe_translator/src/main/resources/templates/params_loader.st
+++ b/tools/caffe_translator/src/main/resources/templates/params_loader.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
def load_params(params_file):
save_dict = mx.nd.load(params_file)
arg_params = {}
diff --git a/tools/caffe_translator/src/main/resources/templates/permute.st b/tools/caffe_translator/src/main/resources/templates/permute.st
index 2b06a76..9f94bdb 100644
--- a/tools/caffe_translator/src/main/resources/templates/permute.st
+++ b/tools/caffe_translator/src/main/resources/templates/permute.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.transpose(data=<data>, axes=(<axes;separator=", ">), name='<name>')
diff --git a/tools/caffe_translator/src/main/resources/templates/pooling.st b/tools/caffe_translator/src/main/resources/templates/pooling.st
index 5389754..7aceffd 100644
--- a/tools/caffe_translator/src/main/resources/templates/pooling.st
+++ b/tools/caffe_translator/src/main/resources/templates/pooling.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.symbol.Pooling(data=<data>,
pool_type='<type>',
<if(global_pool)>
diff --git a/tools/caffe_translator/src/main/resources/templates/power.st b/tools/caffe_translator/src/main/resources/templates/power.st
index a512a67..7fe3ee8 100644
--- a/tools/caffe_translator/src/main/resources/templates/power.st
+++ b/tools/caffe_translator/src/main/resources/templates/power.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = (<shift> + (<scale> * <data>)) ** <power>
diff --git a/tools/caffe_translator/src/main/resources/templates/runner.st b/tools/caffe_translator/src/main/resources/templates/runner.st
index 6df9671..8346ffe 100644
--- a/tools/caffe_translator/src/main/resources/templates/runner.st
+++ b/tools/caffe_translator/src/main/resources/templates/runner.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
ctx = <ctx>
module = mx.mod.Module(symbol=<loss>, context=ctx, data_names=[<data_names;separator=", ">], label_names=[<label_names;separator=", ">])
diff --git a/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st b/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st
index bc63891..57a8e71 100644
--- a/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st
+++ b/tools/caffe_translator/src/main/resources/templates/softmaxoutput.st
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.SoftmaxOutput(data=<data>, label=<label>, name='<name>')
<var>_metric = mx.metric.CrossEntropy(output_names=['<name>_output'], label_names=['<label_name>'], name='<name>/metric')
train_metrics.add(<var>_metric)
diff --git a/tools/caffe_translator/src/main/resources/templates/symbols.stg b/tools/caffe_translator/src/main/resources/templates/symbols.stg
index fda9125..2a76eb0 100644
--- a/tools/caffe_translator/src/main/resources/templates/symbols.stg
+++ b/tools/caffe_translator/src/main/resources/templates/symbols.stg
@@ -1,3 +1,21 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
CaffePluginIntLayer(var, tops, num_data, num_weight, num_out, data, prototxt, name) ::= "<var> = mx.symbol.CaffeOp(<if(data)><data>, <endif><if(num_data)>num_data=<num_data>, <endif><if(num_out)>num_out=<num_out>, <endif><if(num_weight)>num_weight=<num_weight>, <endif>prototxt='<prototxt>', name='<name>')
<if(tops)><tops:{top|<top_assign(top, var, i0)>};separator=\"\n\"> <endif>"
diff --git a/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st b/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st
index de93ee9..29a713f 100644
--- a/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st
+++ b/tools/caffe_translator/src/main/resources/templates/top_k_accuracy.st
@@ -1,2 +1,20 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.metric.TopKAccuracy(top_k=<k>, output_names=['<output_name>'], label_names=['<label_name>'], name='<name>')
test_metrics.add(<var>)
diff --git a/tools/caffe_translator/src/main/resources/templates/var.st b/tools/caffe_translator/src/main/resources/templates/var.st
index e850b689..fa08cd7 100644
--- a/tools/caffe_translator/src/main/resources/templates/var.st
+++ b/tools/caffe_translator/src/main/resources/templates/var.st
@@ -1 +1,19 @@
+<!
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+!>
<var> = mx.sym.Variable('<name>'<if(lr_mult)>, lr_mult=<lr_mult><endif><if(wd_mult)>, wd_mult=<wd_mult><endif><if(init)>, init=<init><endif><if(shape)>, shape=(<shape;separator=", ">)<endif>)