You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/05 20:01:01 UTC
[incubator-mxnet] branch master updated: Fix example/module folder
and remove duplicate examples (#8964)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new c3d6cf0 Fix example/module folder and remove duplicate examples (#8964)
c3d6cf0 is described below
commit c3d6cf09007dc11e1a5e2227075422911aa2eb3e
Author: yuruofeifei <yu...@gmail.com>
AuthorDate: Fri Jan 5 12:00:57 2018 -0800
Fix example/module folder and remove duplicate examples (#8964)
* Fix example/module folder and remove duplicate examples
* Add readme
* remove relative path insert
---
example/module/README.md | 9 +
example/module/lstm_bucketing.py | 103 -------
example/module/mnist_mlp.py | 1 -
example/module/python_loss.py | 2 +-
example/module/sequential_module.py | 1 -
example/module/train_cifar10.ipynb | 593 ------------------------------------
example/module/train_cifar10.py | 215 -------------
7 files changed, 10 insertions(+), 914 deletions(-)
diff --git a/example/module/README.md b/example/module/README.md
new file mode 100644
index 0000000..99dd756
--- /dev/null
+++ b/example/module/README.md
@@ -0,0 +1,9 @@
+# Module Usage Example
+
+This folder contains usage examples for the MXNet Module API.
+
+[mnist_mlp.py](https://github.com/apache/incubator-mxnet/blob/master/example/module/mnist_mlp.py): Trains a simple multilayer perceptron on the MNIST dataset
+
+[python_loss.py](https://github.com/apache/incubator-mxnet/blob/master/example/module/python_loss.py): Usage example for PythonLossModule
+
+[sequential_module.py](https://github.com/apache/incubator-mxnet/blob/master/example/module/sequential_module.py): Usage example for SequentialModule
diff --git a/example/module/lstm_bucketing.py b/example/module/lstm_bucketing.py
deleted file mode 100644
index ecc7e7b..0000000
--- a/example/module/lstm_bucketing.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
-# pylint: disable=superfluous-parens, no-member, invalid-name
-import sys
-import os
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "python")))
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "rnn")))
-import numpy as np
-import mxnet as mx
-
-from lstm import lstm_unroll
-from bucket_io import BucketSentenceIter, default_build_vocab
-
-import os.path
-data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'rnn', 'data'))
-
-def Perplexity(label, pred):
- label = label.T.reshape((-1,))
- loss = 0.
- for i in range(pred.shape[0]):
- loss += -np.log(max(1e-10, pred[i][int(label[i])]))
- return np.exp(loss / label.size)
-
-if __name__ == '__main__':
- batch_size = 32
- buckets = [10, 20, 30, 40, 50, 60]
- #buckets = [32]
- num_hidden = 200
- num_embed = 200
- num_lstm_layer = 2
-
- #num_epoch = 25
- num_epoch = 2
- learning_rate = 0.01
- momentum = 0.0
-
- # dummy data is used to test speed without IO
- dummy_data = False
-
- contexts = [mx.context.gpu(i) for i in range(1)]
-
- vocab = default_build_vocab(os.path.join(data_dir, "ptb.train.txt"))
-
- init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
- init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
- init_states = init_c + init_h
-
- data_train = BucketSentenceIter(os.path.join(data_dir, "ptb.train.txt"), vocab,
- buckets, batch_size, init_states)
- data_val = BucketSentenceIter(os.path.join(data_dir, "ptb.valid.txt"), vocab,
- buckets, batch_size, init_states)
-
- if dummy_data:
- data_train = DummyIter(data_train)
- data_val = DummyIter(data_val)
-
- state_names = [x[0] for x in init_states]
- def sym_gen(seq_len):
- sym = lstm_unroll(num_lstm_layer, seq_len, len(vocab),
- num_hidden=num_hidden, num_embed=num_embed,
- num_label=len(vocab))
- data_names = ['data'] + state_names
- label_names = ['softmax_label']
- return (sym, data_names, label_names)
-
- if len(buckets) == 1:
- mod = mx.mod.Module(*sym_gen(buckets[0]), context=contexts)
- else:
- mod = mx.mod.BucketingModule(sym_gen, default_bucket_key=data_train.default_bucket_key, context=contexts)
-
- import logging
- head = '%(asctime)-15s %(message)s'
- logging.basicConfig(level=logging.DEBUG, format=head)
-
- mod.fit(data_train, eval_data=data_val, num_epoch=num_epoch,
- eval_metric=mx.metric.np(Perplexity),
- batch_end_callback=mx.callback.Speedometer(batch_size, 50),
- initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
- optimizer='sgd',
- optimizer_params={'learning_rate':0.01, 'momentum': 0.9, 'wd': 0.00001})
-
- # Now it is very easy to use the bucketing to do scoring or collect prediction outputs
- metric = mx.metric.np(Perplexity)
- mod.score(data_val, metric)
- for name, val in metric.get_name_value():
- logging.info('Validation-%s=%f', name, val)
-
diff --git a/example/module/mnist_mlp.py b/example/module/mnist_mlp.py
index d2737dc..7d63a58 100644
--- a/example/module/mnist_mlp.py
+++ b/example/module/mnist_mlp.py
@@ -17,7 +17,6 @@
# pylint: skip-file
import os, sys
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import get_data
import mxnet as mx
import numpy as np
diff --git a/example/module/python_loss.py b/example/module/python_loss.py
index 9680ac6..0e15210 100644
--- a/example/module/python_loss.py
+++ b/example/module/python_loss.py
@@ -25,7 +25,7 @@ import logging
@numba.jit
def mc_hinge_grad(scores, labels):
scores = scores.asnumpy()
- labels = labels.asnumpy()
+ labels = labels.asnumpy().astype(int)
n, _ = scores.shape
grad = np.zeros_like(scores)
diff --git a/example/module/sequential_module.py b/example/module/sequential_module.py
index 48e1046..11a548d 100644
--- a/example/module/sequential_module.py
+++ b/example/module/sequential_module.py
@@ -17,7 +17,6 @@
# pylint: skip-file
import os, sys
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils import get_data
import mxnet as mx
import numpy as np
diff --git a/example/module/train_cifar10.ipynb b/example/module/train_cifar10.ipynb
deleted file mode 100644
index 4cfb973..0000000
--- a/example/module/train_cifar10.ipynb
+++ /dev/null
@@ -1,593 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Train CIFAR-10 CNN model\n",
- "using MXNet's \"Module\" interface"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "collapsed": false
- },
- "outputs": [],
- "source": [
- "import mxnet\n",
- "import mxnet as mx\n",
- "import train_cifar10"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": [
- "# Set up the hyper-parameters\n",
- "args = train_cifar10.command_line_args(defaults=True)\n",
- "args.gpus = \"0\"\n",
- "#args.network = \"lenet\" # Fast, not very accurate\n",
- "#args.network = \"inception-bn-28-small\" # Much more accurate & slow"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "collapsed": false
- },
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " <div class=\"bk-root\">\n",
- " <a href=\"http://bokeh.pydata.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
- " <span id=\"32effd9a-c11b-4a18-86b1-2c6781cf5027\">Loading BokehJS ...</span>\n",
- " </div>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "application/javascript": [
- "\n",
- "(function(global) {\n",
- " function now() {\n",
- " return new Date();\n",
- " }\n",
- "\n",
- " var force = \"1\";\n",
- "\n",
- " if (typeof (window._bokeh_onload_callbacks) === \"undefined\" || force !== \"\") {\n",
- " window._bokeh_onload_callbacks = [];\n",
- " window._bokeh_is_loading = undefined;\n",
- " }\n",
- "\n",
- "\n",
- " \n",
- " if (typeof (window._bokeh_timeout) === \"undefined\" || force !== \"\") {\n",
- " window._bokeh_timeout = Date.now() + 5000;\n",
- " window._bokeh_failed_load = false;\n",
- " }\n",
- "\n",
- " var NB_LOAD_WARNING = {'data': {'text/html':\n",
- " \"<div style='background-color: #fdd'>\\n\"+\n",
- " \"<p>\\n\"+\n",
- " \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
- " \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
- " \"</p>\\n\"+\n",
- " \"<ul>\\n\"+\n",
- " \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
- " \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
- " \"</ul>\\n\"+\n",
- " \"<code>\\n\"+\n",
- " \"from bokeh.resources import INLINE\\n\"+\n",
- " \"output_notebook(resources=INLINE)\\n\"+\n",
- " \"</code>\\n\"+\n",
- " \"</div>\"}};\n",
- "\n",
- " function display_loaded() {\n",
- " if (window.Bokeh !== undefined) {\n",
- " Bokeh.$(\"#32effd9a-c11b-4a18-86b1-2c6781cf5027\").text(\"BokehJS successfully loaded.\");\n",
- " } else if (Date.now() < window._bokeh_timeout) {\n",
- " setTimeout(display_loaded, 100)\n",
- " }\n",
- " }\n",
- "\n",
- " function run_callbacks() {\n",
- " window._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n",
- " delete window._bokeh_onload_callbacks\n",
- " console.info(\"Bokeh: all callbacks have finished\");\n",
- " }\n",
- "\n",
- " function load_libs(js_urls, callback) {\n",
- " window._bokeh_onload_callbacks.push(callback);\n",
- " if (window._bokeh_is_loading > 0) {\n",
- " console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
- " return null;\n",
- " }\n",
- " if (js_urls == null || js_urls.length === 0) {\n",
- " run_callbacks();\n",
- " return null;\n",
- " }\n",
- " console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
- " window._bokeh_is_loading = js_urls.length;\n",
- " for (var i = 0; i < js_urls.length; i++) {\n",
- " var url = js_urls[i];\n",
- " var s = document.createElement('script');\n",
- " s.src = url;\n",
- " s.async = false;\n",
- " s.onreadystatechange = s.onload = function() {\n",
- " window._bokeh_is_loading--;\n",
- " if (window._bokeh_is_loading === 0) {\n",
- " console.log(\"Bokeh: all BokehJS libraries loaded\");\n",
- " run_callbacks()\n",
- " }\n",
- " };\n",
- " s.onerror = function() {\n",
- " console.warn(\"failed to load library \" + url);\n",
- " };\n",
- " console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
- " document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
- " }\n",
- " };var element = document.getElementById(\"32effd9a-c11b-4a18-86b1-2c6781cf5027\");\n",
- " if (element == null) {\n",
- " console.log(\"Bokeh: ERROR: autoload.js configured with elementid '32effd9a-c11b-4a18-86b1-2c6781cf5027' but no matching script tag was found. \")\n",
- " return false;\n",
- " }\n",
- "\n",
- " var js_urls = ['https://cdn.pydata.org/bokeh/release/bokeh-0.12.3.min.js', 'https://cdn.pydata.org/bokeh/release/bokeh-widgets-0.12.3.min.js'];\n",
- "\n",
- " var inline_js = [\n",
- " function(Bokeh) {\n",
- " Bokeh.set_log_level(\"info\");\n",
- " },\n",
- " \n",
- " function(Bokeh) {\n",
- " \n",
- " Bokeh.$(\"#32effd9a-c11b-4a18-86b1-2c6781cf5027\").text(\"BokehJS is loading...\");\n",
- " },\n",
- " function(Bokeh) {\n",
- " console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-0.12.3.min.css\");\n",
- " Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-0.12.3.min.css\");\n",
- " console.log(\"Bokeh: injecting CSS: https://cdn.pydata.org/bokeh/release/bokeh-widgets-0.12.3.min.css\");\n",
- " Bokeh.embed.inject_css(\"https://cdn.pydata.org/bokeh/release/bokeh-widgets-0.12.3.min.css\");\n",
- " }\n",
- " ];\n",
- "\n",
- " function run_inline_js() {\n",
- " \n",
- " if ((window.Bokeh !== undefined) || (force === \"1\")) {\n",
- " for (var i = 0; i < inline_js.length; i++) {\n",
- " inline_js[i](window.Bokeh);\n",
- " }if (force === \"1\") {\n",
- " display_loaded();\n",
- " }} else if (Date.now() < window._bokeh_timeout) {\n",
- " setTimeout(run_inline_js, 100);\n",
- " } else if (!window._bokeh_failed_load) {\n",
- " console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
- " window._bokeh_failed_load = true;\n",
- " } else if (!force) {\n",
- " var cell = $(\"#32effd9a-c11b-4a18-86b1-2c6781cf5027\").parents('.cell').data().cell;\n",
- " cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
- " }\n",
- "\n",
- " }\n",
- "\n",
- " if (window._bokeh_is_loading === 0) {\n",
- " console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
- " run_inline_js();\n",
- " } else {\n",
- " load_libs(js_urls, function() {\n",
- " console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n",
- " run_inline_js();\n",
- " });\n",
- " }\n",
- "}(this));"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- " <div class=\"bk-root\">\n",
- " <div class=\"plotdiv\" id=\"8813280e-065e-435d-846c-31a252dddcc7\"></div>\n",
- " </div>\n",
- "<script type=\"text/javascript\">\n",
- " \n",
- " (function(global) {\n",
- " function now() {\n",
- " return new Date();\n",
- " }\n",
- " \n",
- " var force = \"\";\n",
- " \n",
- " if (typeof (window._bokeh_onload_callbacks) === \"undefined\" || force !== \"\") {\n",
- " window._bokeh_onload_callbacks = [];\n",
- " window._bokeh_is_loading = undefined;\n",
- " }\n",
- " \n",
- " \n",
- " \n",
- " if (typeof (window._bokeh_timeout) === \"undefined\" || force !== \"\") {\n",
- " window._bokeh_timeout = Date.now() + 0;\n",
- " window._bokeh_failed_load = false;\n",
- " }\n",
- " \n",
- " var NB_LOAD_WARNING = {'data': {'text/html':\n",
- " \"<div style='background-color: #fdd'>\\n\"+\n",
- " \"<p>\\n\"+\n",
- " \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n",
- " \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n",
- " \"</p>\\n\"+\n",
- " \"<ul>\\n\"+\n",
- " \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n",
- " \"<li>use INLINE resources instead, as so:</li>\\n\"+\n",
- " \"</ul>\\n\"+\n",
- " \"<code>\\n\"+\n",
- " \"from bokeh.resources import INLINE\\n\"+\n",
- " \"output_notebook(resources=INLINE)\\n\"+\n",
- " \"</code>\\n\"+\n",
- " \"</div>\"}};\n",
- " \n",
- " function display_loaded() {\n",
- " if (window.Bokeh !== undefined) {\n",
- " Bokeh.$(\"#8813280e-065e-435d-846c-31a252dddcc7\").text(\"BokehJS successfully loaded.\");\n",
- " } else if (Date.now() < window._bokeh_timeout) {\n",
- " setTimeout(display_loaded, 100)\n",
- " }\n",
- " }if ((window.Jupyter !== undefined) && Jupyter.notebook.kernel) {\n",
- " comm_manager = Jupyter.notebook.kernel.comm_manager\n",
- " comm_manager.register_target(\"29afa51c-a944-4f51-8ffa-a03b925a1f47\", function () {});\n",
- " }\n",
- " \n",
- " function run_callbacks() {\n",
- " window._bokeh_onload_callbacks.forEach(function(callback) { callback() });\n",
- " delete window._bokeh_onload_callbacks\n",
- " console.info(\"Bokeh: all callbacks have finished\");\n",
- " }\n",
- " \n",
- " function load_libs(js_urls, callback) {\n",
- " window._bokeh_onload_callbacks.push(callback);\n",
- " if (window._bokeh_is_loading > 0) {\n",
- " console.log(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n",
- " return null;\n",
- " }\n",
- " if (js_urls == null || js_urls.length === 0) {\n",
- " run_callbacks();\n",
- " return null;\n",
- " }\n",
- " console.log(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n",
- " window._bokeh_is_loading = js_urls.length;\n",
- " for (var i = 0; i < js_urls.length; i++) {\n",
- " var url = js_urls[i];\n",
- " var s = document.createElement('script');\n",
- " s.src = url;\n",
- " s.async = false;\n",
- " s.onreadystatechange = s.onload = function() {\n",
- " window._bokeh_is_loading--;\n",
- " if (window._bokeh_is_loading === 0) {\n",
- " console.log(\"Bokeh: all BokehJS libraries loaded\");\n",
- " run_callbacks()\n",
- " }\n",
- " };\n",
- " s.onerror = function() {\n",
- " console.warn(\"failed to load library \" + url);\n",
- " };\n",
- " console.log(\"Bokeh: injecting script tag for BokehJS library: \", url);\n",
- " document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
- " }\n",
- " };var element = document.getElementById(\"8813280e-065e-435d-846c-31a252dddcc7\");\n",
- " if (element == null) {\n",
- " console.log(\"Bokeh: ERROR: autoload.js configured with elementid '8813280e-065e-435d-846c-31a252dddcc7' but no matching script tag was found. \")\n",
- " return false;\n",
- " }\n",
- " \n",
- " var js_urls = [];\n",
- " \n",
- " var inline_js = [\n",
- " function(Bokeh) {\n",
- " Bokeh.$(function() {\n",
- " var docs_json = {\"f847dff6-ab8a-4fd1-8be0-e79b39334749\":{\"roots\":{\"references\":[{\"attributes\":{\"months\":[0,6]},\"id\":\"dd14ae70-3ab7-4fd4-8fae-0e8db050826a\",\"type\":\"MonthsTicker\"},{\"attributes\":{\"months\":[0,1,2,3,4,5,6,7,8,9,10,11]},\"id\":\"e21f9d30-e2e0-491a-abaf-5ead1f475c72\",\"type\":\"MonthsTicker\"},{\"attributes\":{\"fill_alpha\":{\"value\":0.1},\"fill_color\":{\"value\":\"#1f77b4\"},\"line_alpha\":{\"value\":0.1},\"line_color\":{\"value\": [...]
- " var render_items = [{\"docid\":\"f847dff6-ab8a-4fd1-8be0-e79b39334749\",\"elementid\":\"8813280e-065e-435d-846c-31a252dddcc7\",\"modelid\":\"01107ba2-ae75-41b0-bddc-5e93690ea7e8\",\"notebook_comms_target\":\"29afa51c-a944-4f51-8ffa-a03b925a1f47\"}];\n",
- " \n",
- " Bokeh.embed.embed_items(docs_json, render_items);\n",
- " });\n",
- " },\n",
- " function(Bokeh) {\n",
- " }\n",
- " ];\n",
- " \n",
- " function run_inline_js() {\n",
- " \n",
- " if ((window.Bokeh !== undefined) || (force === \"1\")) {\n",
- " for (var i = 0; i < inline_js.length; i++) {\n",
- " inline_js[i](window.Bokeh);\n",
- " }if (force === \"1\") {\n",
- " display_loaded();\n",
- " }} else if (Date.now() < window._bokeh_timeout) {\n",
- " setTimeout(run_inline_js, 100);\n",
- " } else if (!window._bokeh_failed_load) {\n",
- " console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n",
- " window._bokeh_failed_load = true;\n",
- " } else if (!force) {\n",
- " var cell = $(\"#8813280e-065e-435d-846c-31a252dddcc7\").parents('.cell').data().cell;\n",
- " cell.output_area.append_execute_result(NB_LOAD_WARNING)\n",
- " }\n",
- " \n",
- " }\n",
- " \n",
- " if (window._bokeh_is_loading === 0) {\n",
- " console.log(\"Bokeh: BokehJS loaded, going straight to plotting\");\n",
- " run_inline_js();\n",
- " } else {\n",
- " load_libs(js_urls, function() {\n",
- " console.log(\"Bokeh: BokehJS plotting callback run at\", now());\n",
- " run_inline_js();\n",
- " });\n",
- " }\n",
- " }(this));\n",
- "</script>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Configure charts to plot while training\n",
- "from mxnet.notebook.callback import LiveLearningCurve\n",
- "cb_args = LiveLearningCurve('accuracy', 5).callback_args()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "collapsed": false,
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2016-11-04 18:05:30,714 Node[0] start with arguments Namespace(batch_size=128, data_dir='/efs/datasets/users/leodirac/code/workplace/leodirac/mxnet/example/image-classification/cifar10/', gpus='0', kv_store='local', load_epoch=None, lr=0.05, lr_factor=1, lr_factor_epoch=1, model_prefix=None, network='inception-bn-28-small', num_epochs=20, num_examples=60000, save_model_prefix=None)\n",
- "2016-11-04 18:05:30,715 Node[0] running on ip-172-31-59-245\n",
- "2016-11-04 18:05:32,172 Node[0] Starting with devices [gpu(0)]\n",
- "2016-11-04 18:05:32,175 Node[0] start training for 20 epochs...\n",
- "2016-11-04 18:07:11,468 Node[0] Epoch[0] Train-accuracy=0.566211\n",
- "2016-11-04 18:07:11,469 Node[0] Epoch[0] Train-top_k_accuracy_5=0.948242\n",
- "2016-11-04 18:07:11,470 Node[0] Epoch[0] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:07:11,471 Node[0] Epoch[0] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:07:11,471 Node[0] Epoch[0] Time cost=98.542\n",
- "2016-11-04 18:07:17,454 Node[0] Epoch[0] Validation-accuracy=nan\n",
- "2016-11-04 18:07:17,455 Node[0] Epoch[0] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:07:17,456 Node[0] Epoch[0] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:07:17,457 Node[0] Epoch[0] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:08:57,069 Node[0] Epoch[1] Train-accuracy=0.679883\n",
- "2016-11-04 18:08:57,069 Node[0] Epoch[1] Train-top_k_accuracy_5=0.973047\n",
- "2016-11-04 18:08:57,071 Node[0] Epoch[1] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:08:57,072 Node[0] Epoch[1] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:08:57,072 Node[0] Epoch[1] Time cost=99.614\n",
- "2016-11-04 18:09:02,598 Node[0] Epoch[1] Validation-accuracy=nan\n",
- "2016-11-04 18:09:02,599 Node[0] Epoch[1] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:09:02,599 Node[0] Epoch[1] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:09:02,600 Node[0] Epoch[1] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:10:42,219 Node[0] Epoch[2] Train-accuracy=0.742388\n",
- "2016-11-04 18:10:42,221 Node[0] Epoch[2] Train-top_k_accuracy_5=0.981170\n",
- "2016-11-04 18:10:42,222 Node[0] Epoch[2] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:10:42,223 Node[0] Epoch[2] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:10:42,223 Node[0] Epoch[2] Time cost=99.622\n",
- "2016-11-04 18:10:47,755 Node[0] Epoch[2] Validation-accuracy=nan\n",
- "2016-11-04 18:10:47,756 Node[0] Epoch[2] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:10:47,757 Node[0] Epoch[2] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:10:47,758 Node[0] Epoch[2] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:12:27,771 Node[0] Epoch[3] Train-accuracy=0.775586\n",
- "2016-11-04 18:12:27,772 Node[0] Epoch[3] Train-top_k_accuracy_5=0.987891\n",
- "2016-11-04 18:12:27,773 Node[0] Epoch[3] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:12:27,774 Node[0] Epoch[3] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:12:27,775 Node[0] Epoch[3] Time cost=100.017\n",
- "2016-11-04 18:12:33,304 Node[0] Epoch[3] Validation-accuracy=nan\n",
- "2016-11-04 18:12:33,305 Node[0] Epoch[3] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:12:33,306 Node[0] Epoch[3] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:12:33,307 Node[0] Epoch[3] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:14:13,252 Node[0] Epoch[4] Train-accuracy=0.793945\n",
- "2016-11-04 18:14:13,253 Node[0] Epoch[4] Train-top_k_accuracy_5=0.991016\n",
- "2016-11-04 18:14:13,253 Node[0] Epoch[4] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:14:13,254 Node[0] Epoch[4] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:14:13,255 Node[0] Epoch[4] Time cost=99.948\n",
- "2016-11-04 18:14:18,787 Node[0] Epoch[4] Validation-accuracy=nan\n",
- "2016-11-04 18:14:18,787 Node[0] Epoch[4] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:14:18,788 Node[0] Epoch[4] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:14:18,789 Node[0] Epoch[4] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:15:58,508 Node[0] Epoch[5] Train-accuracy=0.819311\n",
- "2016-11-04 18:15:58,509 Node[0] Epoch[5] Train-top_k_accuracy_5=0.991987\n",
- "2016-11-04 18:15:58,510 Node[0] Epoch[5] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:15:58,511 Node[0] Epoch[5] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:15:58,511 Node[0] Epoch[5] Time cost=99.722\n",
- "2016-11-04 18:16:04,043 Node[0] Epoch[5] Validation-accuracy=nan\n",
- "2016-11-04 18:16:04,044 Node[0] Epoch[5] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:16:04,045 Node[0] Epoch[5] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:16:04,046 Node[0] Epoch[5] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:17:43,899 Node[0] Epoch[6] Train-accuracy=0.829688\n",
- "2016-11-04 18:17:43,900 Node[0] Epoch[6] Train-top_k_accuracy_5=0.993750\n",
- "2016-11-04 18:17:43,901 Node[0] Epoch[6] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:17:43,902 Node[0] Epoch[6] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:17:43,903 Node[0] Epoch[6] Time cost=99.856\n",
- "2016-11-04 18:17:49,441 Node[0] Epoch[6] Validation-accuracy=nan\n",
- "2016-11-04 18:17:49,441 Node[0] Epoch[6] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:17:49,442 Node[0] Epoch[6] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:17:49,443 Node[0] Epoch[6] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:19:29,163 Node[0] Epoch[7] Train-accuracy=0.844151\n",
- "2016-11-04 18:19:29,164 Node[0] Epoch[7] Train-top_k_accuracy_5=0.994992\n",
- "2016-11-04 18:19:29,165 Node[0] Epoch[7] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:19:29,166 Node[0] Epoch[7] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:19:29,167 Node[0] Epoch[7] Time cost=99.723\n",
- "2016-11-04 18:19:34,688 Node[0] Epoch[7] Validation-accuracy=nan\n",
- "2016-11-04 18:19:34,689 Node[0] Epoch[7] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:19:34,689 Node[0] Epoch[7] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:19:34,690 Node[0] Epoch[7] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:21:15,117 Node[0] Epoch[8] Train-accuracy=0.862695\n",
- "2016-11-04 18:21:15,118 Node[0] Epoch[8] Train-top_k_accuracy_5=0.995703\n",
- "2016-11-04 18:21:15,118 Node[0] Epoch[8] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:21:15,120 Node[0] Epoch[8] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:21:15,121 Node[0] Epoch[8] Time cost=100.430\n",
- "2016-11-04 18:21:21,087 Node[0] Epoch[8] Validation-accuracy=nan\n",
- "2016-11-04 18:21:21,088 Node[0] Epoch[8] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:21:21,088 Node[0] Epoch[8] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:21:21,089 Node[0] Epoch[8] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:23:00,978 Node[0] Epoch[9] Train-accuracy=0.872070\n",
- "2016-11-04 18:23:00,979 Node[0] Epoch[9] Train-top_k_accuracy_5=0.994141\n",
- "2016-11-04 18:23:00,981 Node[0] Epoch[9] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:23:00,982 Node[0] Epoch[9] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:23:00,983 Node[0] Epoch[9] Time cost=99.893\n",
- "2016-11-04 18:23:06,510 Node[0] Epoch[9] Validation-accuracy=nan\n",
- "2016-11-04 18:23:06,511 Node[0] Epoch[9] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:23:06,511 Node[0] Epoch[9] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:23:06,512 Node[0] Epoch[9] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:24:46,159 Node[0] Epoch[10] Train-accuracy=0.878606\n",
- "2016-11-04 18:24:46,160 Node[0] Epoch[10] Train-top_k_accuracy_5=0.997196\n",
- "2016-11-04 18:24:46,161 Node[0] Epoch[10] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:24:46,162 Node[0] Epoch[10] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:24:46,163 Node[0] Epoch[10] Time cost=99.646\n",
- "2016-11-04 18:24:51,684 Node[0] Epoch[10] Validation-accuracy=nan\n",
- "2016-11-04 18:24:51,685 Node[0] Epoch[10] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:24:51,685 Node[0] Epoch[10] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:24:51,686 Node[0] Epoch[10] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:26:31,614 Node[0] Epoch[11] Train-accuracy=0.889062\n",
- "2016-11-04 18:26:31,615 Node[0] Epoch[11] Train-top_k_accuracy_5=0.996094\n",
- "2016-11-04 18:26:31,616 Node[0] Epoch[11] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:26:31,617 Node[0] Epoch[11] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:26:31,618 Node[0] Epoch[11] Time cost=99.931\n",
- "2016-11-04 18:26:37,150 Node[0] Epoch[11] Validation-accuracy=nan\n",
- "2016-11-04 18:26:37,151 Node[0] Epoch[11] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:26:37,152 Node[0] Epoch[11] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:26:37,152 Node[0] Epoch[11] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:28:17,021 Node[0] Epoch[12] Train-accuracy=0.895117\n",
- "2016-11-04 18:28:17,023 Node[0] Epoch[12] Train-top_k_accuracy_5=0.997266\n",
- "2016-11-04 18:28:17,024 Node[0] Epoch[12] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:28:17,025 Node[0] Epoch[12] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:28:17,026 Node[0] Epoch[12] Time cost=99.873\n",
- "2016-11-04 18:28:22,553 Node[0] Epoch[12] Validation-accuracy=nan\n",
- "2016-11-04 18:28:22,554 Node[0] Epoch[12] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:28:22,555 Node[0] Epoch[12] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:28:22,556 Node[0] Epoch[12] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:30:02,235 Node[0] Epoch[13] Train-accuracy=0.904447\n",
- "2016-11-04 18:30:02,236 Node[0] Epoch[13] Train-top_k_accuracy_5=0.997997\n",
- "2016-11-04 18:30:02,237 Node[0] Epoch[13] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:30:02,238 Node[0] Epoch[13] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:30:02,239 Node[0] Epoch[13] Time cost=99.683\n",
- "2016-11-04 18:30:07,772 Node[0] Epoch[13] Validation-accuracy=nan\n",
- "2016-11-04 18:30:07,773 Node[0] Epoch[13] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:30:07,774 Node[0] Epoch[13] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:30:07,774 Node[0] Epoch[13] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:31:47,692 Node[0] Epoch[14] Train-accuracy=0.902734\n",
- "2016-11-04 18:31:47,693 Node[0] Epoch[14] Train-top_k_accuracy_5=0.998633\n",
- "2016-11-04 18:31:47,694 Node[0] Epoch[14] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:31:47,695 Node[0] Epoch[14] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:31:47,696 Node[0] Epoch[14] Time cost=99.921\n",
- "2016-11-04 18:31:53,226 Node[0] Epoch[14] Validation-accuracy=nan\n",
- "2016-11-04 18:31:53,227 Node[0] Epoch[14] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:31:53,228 Node[0] Epoch[14] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:31:53,228 Node[0] Epoch[14] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:33:32,868 Node[0] Epoch[15] Train-accuracy=0.919471\n",
- "2016-11-04 18:33:32,869 Node[0] Epoch[15] Train-top_k_accuracy_5=0.998798\n",
- "2016-11-04 18:33:32,870 Node[0] Epoch[15] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:33:32,871 Node[0] Epoch[15] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:33:32,872 Node[0] Epoch[15] Time cost=99.643\n",
- "2016-11-04 18:33:38,396 Node[0] Epoch[15] Validation-accuracy=nan\n",
- "2016-11-04 18:33:38,397 Node[0] Epoch[15] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:33:38,398 Node[0] Epoch[15] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:33:38,398 Node[0] Epoch[15] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:35:18,826 Node[0] Epoch[16] Train-accuracy=0.915234\n",
- "2016-11-04 18:35:18,827 Node[0] Epoch[16] Train-top_k_accuracy_5=0.997656\n",
- "2016-11-04 18:35:18,828 Node[0] Epoch[16] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:35:18,829 Node[0] Epoch[16] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:35:18,830 Node[0] Epoch[16] Time cost=100.431\n",
- "2016-11-04 18:35:24,759 Node[0] Epoch[16] Validation-accuracy=nan\n",
- "2016-11-04 18:35:24,759 Node[0] Epoch[16] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:35:24,760 Node[0] Epoch[16] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:35:24,761 Node[0] Epoch[16] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:37:04,619 Node[0] Epoch[17] Train-accuracy=0.918555\n",
- "2016-11-04 18:37:04,620 Node[0] Epoch[17] Train-top_k_accuracy_5=0.998437\n",
- "2016-11-04 18:37:04,621 Node[0] Epoch[17] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:37:04,621 Node[0] Epoch[17] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:37:04,622 Node[0] Epoch[17] Time cost=99.861\n",
- "2016-11-04 18:37:10,154 Node[0] Epoch[17] Validation-accuracy=nan\n",
- "2016-11-04 18:37:10,155 Node[0] Epoch[17] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:37:10,156 Node[0] Epoch[17] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:37:10,156 Node[0] Epoch[17] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:38:49,811 Node[0] Epoch[18] Train-accuracy=0.924479\n",
- "2016-11-04 18:38:49,812 Node[0] Epoch[18] Train-top_k_accuracy_5=0.998598\n",
- "2016-11-04 18:38:49,813 Node[0] Epoch[18] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:38:49,814 Node[0] Epoch[18] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:38:49,814 Node[0] Epoch[18] Time cost=99.658\n",
- "2016-11-04 18:38:55,339 Node[0] Epoch[18] Validation-accuracy=nan\n",
- "2016-11-04 18:38:55,340 Node[0] Epoch[18] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:38:55,340 Node[0] Epoch[18] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:38:55,341 Node[0] Epoch[18] Validation-top_k_accuracy_20=nan\n",
- "2016-11-04 18:40:35,220 Node[0] Epoch[19] Train-accuracy=0.927148\n",
- "2016-11-04 18:40:35,222 Node[0] Epoch[19] Train-top_k_accuracy_5=0.998828\n",
- "2016-11-04 18:40:35,222 Node[0] Epoch[19] Train-top_k_accuracy_10=1.000000\n",
- "2016-11-04 18:40:35,223 Node[0] Epoch[19] Train-top_k_accuracy_20=1.000000\n",
- "2016-11-04 18:40:35,224 Node[0] Epoch[19] Time cost=99.880\n",
- "2016-11-04 18:40:40,751 Node[0] Epoch[19] Validation-accuracy=nan\n",
- "2016-11-04 18:40:40,752 Node[0] Epoch[19] Validation-top_k_accuracy_5=nan\n",
- "2016-11-04 18:40:40,753 Node[0] Epoch[19] Validation-top_k_accuracy_10=nan\n",
- "2016-11-04 18:40:40,753 Node[0] Epoch[19] Validation-top_k_accuracy_20=nan\n"
- ]
- }
- ],
- "source": [
- "# Start training\n",
- "train_cifar10.do_train(args, \n",
- " callback_args=cb_args,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "collapsed": true
- },
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "language": "python",
- "name": "python2"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 2
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 1
-}
diff --git a/example/module/train_cifar10.py b/example/module/train_cifar10.py
deleted file mode 100644
index a96e8d9..0000000
--- a/example/module/train_cifar10.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Train CIFAR-10 classifier in MXNet.
-Demonstrates using the Module class.
-"""
-import logging
-import os, sys
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification")))
-
-import find_mxnet
-import mxnet as mx
-import argparse
-import train_model
-import importlib
-import platform
-
-
-def command_line_args(defaults=False):
- parser = argparse.ArgumentParser(description=__doc__,
- formatter_class=argparse.RawTextHelpFormatter)
- parser.add_argument('--network', type=str, default='inception-bn-28-small',
- help = 'which CNN style to use')
- my_dir = os.path.dirname(__file__)
- default_data_dir = os.path.abspath(os.path.join(my_dir, '..', 'image-classification', 'cifar10')) + '/'
- parser.add_argument('--data-dir', type=str, default=default_data_dir,
- help='the input data directory')
- parser.add_argument('--gpus', type=str,
- help='the gpus will be used, e.g "0,1,2,3"')
- parser.add_argument('--num-examples', type=int, default=60000,
- help='the number of training examples')
- parser.add_argument('--batch-size', type=int, default=128,
- help='the batch size')
- parser.add_argument('--lr', type=float, default=.05,
- help='the initial learning rate')
- parser.add_argument('--lr-factor', type=float, default=1,
- help='times the lr with a factor for every lr-factor-epoch epoch')
- parser.add_argument('--lr-factor-epoch', type=float, default=1,
- help='the number of epoch to factor the lr, could be .5')
- parser.add_argument('--model-prefix', type=str,
- help='the prefix of the model to load')
- parser.add_argument('--save-model-prefix', type=str,
- help='the prefix of the model to save')
- parser.add_argument('--num-epochs', type=int, default=20,
- help='the number of training epochs')
- parser.add_argument('--load-epoch', type=int,
- help="load the model on an epoch using the model-prefix")
- parser.add_argument('--kv-store', type=str, default='local',
- help='the kvstore type')
- if defaults:
- return parser.parse_args([])
- else:
- return parser.parse_args()
-
-
-# download data if necessary
-def _download(data_dir):
- if not os.path.isdir(data_dir):
- os.system("mkdir " + data_dir)
- cwd = os.path.abspath(os.getcwd())
- os.chdir(data_dir)
- if (not os.path.exists('train.rec')) or \
- (not os.path.exists('test.rec')) :
- import urllib, zipfile, glob
- dirname = os.getcwd()
- zippath = os.path.join(dirname, "cifar10.zip")
- urllib.urlretrieve("http://data.mxnet.io/mxnet/data/cifar10.zip", zippath)
- zf = zipfile.ZipFile(zippath, "r")
- zf.extractall()
- zf.close()
- os.remove(zippath)
- for f in glob.glob(os.path.join(dirname, "cifar", "*")):
- name = f.split(os.path.sep)[-1]
- os.rename(f, os.path.join(dirname, name))
- os.rmdir(os.path.join(dirname, "cifar"))
- os.chdir(cwd)
-
-# data
-def get_iterator(args, kv):
- data_shape = (3, 28, 28)
- data_dir = args.data_dir
- if os.name == "nt":
- data_dir = data_dir[:-1] + "\\"
- if '://' not in args.data_dir:
- _download(data_dir)
-
- train = mx.io.ImageRecordIter(
- path_imgrec = data_dir + "train.rec",
- mean_img = data_dir + "mean.bin",
- data_shape = data_shape,
- batch_size = args.batch_size,
- rand_crop = True,
- rand_mirror = True,
- num_parts = kv.num_workers,
- part_index = kv.rank)
-
- val = mx.io.ImageRecordIter(
- path_imgrec = data_dir + "test.rec",
- mean_img = data_dir + "mean.bin",
- rand_crop = False,
- rand_mirror = False,
- data_shape = data_shape,
- batch_size = args.batch_size,
- num_parts = kv.num_workers,
- part_index = kv.rank)
-
- return (train, val)
-
-
-def do_train(args, callback_args=None):
- # network
- net = importlib.import_module('symbol_' + args.network).get_symbol(10)
-
- my_dir = os.path.dirname(__file__)
- if args.model_prefix is not None:
- args.model_prefix = os.path.abspath(os.path.join(my_dir, args.model_prefix))
- if args.save_model_prefix is not None:
- args.save_model_prefix = os.path.abspath(os.path.join(my_dir, args.save_model_prefix))
-
-
- ################################################################################
- # train
- ################################################################################
-
- # kvstore
- kv = mx.kvstore.create(args.kv_store)
-
- # logging
- head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
- logging.basicConfig(level=logging.DEBUG, format=head)
- logging.info('start with arguments %s', args)
-
- logging.info('running on %s', platform.node())
-
- (train, val) = get_iterator(args, kv)
-
- if args.gpus is None or args.gpus == '':
- devs = mx.cpu()
- elif type(args.gpus) == str:
- devs = [mx.gpu(int(i)) for i in args.gpus.split(',')]
- else:
- devs = mx.gpu(int(args.gpus))
- logging.info('Starting with devices %s', devs)
-
- mod = mx.mod.Module(net, context=devs)
-
- # load model
- model_prefix = args.model_prefix
-
- if args.load_epoch is not None:
- assert model_prefix is not None
- logging.info('loading model from %s-%d...' % (model_prefix, args.load_epoch))
- sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
- else:
- arg_params = None
- aux_params = None
-
- # save model
- save_model_prefix = args.save_model_prefix
- if save_model_prefix is None:
- save_model_prefix = model_prefix
- checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)
-
- optim_args = {'learning_rate': args.lr, 'wd': 0.00001, 'momentum': 0.9}
- if 'lr_factor' in args and args.lr_factor < 1:
- optim_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
- step = max(int(epoch_size * args.lr_factor_epoch), 1),
- factor = args.lr_factor)
-
- if 'clip_gradient' in args and args.clip_gradient is not None:
- optim_args['clip_gradient'] = args.clip_gradient
-
- eval_metrics = ['accuracy']
- ## TopKAccuracy only allows top_k > 1
- for top_k in [5, 10, 20]:
- eval_metrics.append(mx.metric.create('top_k_accuracy', top_k = top_k))
-
- if args.load_epoch:
- begin_epoch = args.load_epoch+1
- else:
- begin_epoch = 0
-
- if not callback_args:
- callback_args = {
- 'batch_end_callback': mx.callback.Speedometer(args.batch_size, 50),
- 'epoch_end_callback': checkpoint,
- }
- else:
- pass
- #TODO: add checkpoint back in
-
- logging.info('start training for %d epochs...', args.num_epochs)
- mod.fit(train, eval_data=val, optimizer_params=optim_args,
- eval_metric=eval_metrics, num_epoch=args.num_epochs,
- arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch,
- **callback_args)
-
-if __name__ == "__main__":
- args = command_line_args()
- do_train(args)
-
--
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <co...@mxnet.apache.org>'].