Posted to commits@mxnet.apache.org by ta...@apache.org on 2019/02/16 11:06:51 UTC

[incubator-mxnet] branch master updated: Add an inference script providing both accuracy and benchmark result for original wide_n_deep example (#13895)

This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8bfbb7d  Add an inference script providing both accuracy and benchmark result for original wide_n_deep example (#13895)
8bfbb7d is described below

commit 8bfbb7de46ce309e1935967ea2dfeb99f8d8a1f0
Author: Shufan <33...@users.noreply.github.com>
AuthorDate: Sat Feb 16 19:06:25 2019 +0800

    Add an inference script providing both accuracy and benchmark result for original wide_n_deep example (#13895)
    
    * Add an inference script that can provide both accuracy and benchmark results
    
    * minor changes
    
    * minor fix to keep the coding style similar to other examples
    
    * fix typo
    
    * remove code redundancy and other minor changes
    
    * Addressing review comments and minor pylint fix
    
    * remove parameter 'accuracy' to make logic simple
---
 example/sparse/wide_deep/README.md    |   5 +-
 example/sparse/wide_deep/config.py    |  28 +++++++++
 example/sparse/wide_deep/inference.py | 106 ++++++++++++++++++++++++++++++++++
 example/sparse/wide_deep/train.py     |  20 ++-----
 4 files changed, 142 insertions(+), 17 deletions(-)

diff --git a/example/sparse/wide_deep/README.md b/example/sparse/wide_deep/README.md
index 769d723..d0ae8ad 100644
--- a/example/sparse/wide_deep/README.md
+++ b/example/sparse/wide_deep/README.md
@@ -20,5 +20,8 @@
 The example demonstrates how to train [wide and deep model](https://arxiv.org/abs/1606.07792). The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) that this example uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). Tricks of feature engineering are adapted from tensorflow's [wide and deep tutorial](https://github.com/tensorflow/models/tree/master/official/wide_deep).
 
 The final accuracy should be around 85%.
-
+For training:
 - `python train.py`
+
+For inference:
+- `python inference.py`
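Note that inference.py (added below) exposes a few switches beyond the two README commands; the flag names here are taken directly from its argparse definitions. Assuming train.py has already saved a checkpoint under the default 'checkpoint' prefix, typical invocations would be:

    python inference.py                    # accuracy on the eval set (default mode)
    python inference.py --benchmark        # throughput benchmark, 100 batches by default
    python inference.py --benchmark --gpu  # the same benchmark on GPU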
diff --git a/example/sparse/wide_deep/config.py b/example/sparse/wide_deep/config.py
new file mode 100644
index 0000000..c0d20c4
--- /dev/null
+++ b/example/sparse/wide_deep/config.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Related to feature engineering, please see preprocess in data.py
+ADULT = {
+    'train': 'adult.data',
+    'test': 'adult.test',
+    'url': 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/',
+    'num_linear_features': 3000,
+    'num_embed_features': 2,
+    'num_cont_features': 38,
+    'embed_input_dims': [1000, 1000],
+    'hidden_units': [8, 50, 100],
+}
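The keys in ADULT are consumed by the data and model helpers shared between train.py and inference.py. As a minimal sketch of the data-loading half (get_uci_adult lives in the example's data.py; its signature is inferred from the call sites in this patch):

    import os
    from config import ADULT
    from data import get_uci_adult

    data_dir = os.path.join(os.getcwd(), 'data')
    # downloads adult.test from the UCI url on first use, then returns the
    # sparse (CSR) wide features, the dense deep features, and the labels
    val_csr, val_dns, val_label = get_uci_adult(data_dir, ADULT['test'], ADULT['url'])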
diff --git a/example/sparse/wide_deep/inference.py b/example/sparse/wide_deep/inference.py
new file mode 100644
index 0000000..e14396e
--- /dev/null
+++ b/example/sparse/wide_deep/inference.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import *
+from config import *
+from data import get_uci_adult
+from model import wide_deep_model
+import argparse
+import os
+import time
+
+parser = argparse.ArgumentParser(description="Run sparse wide and deep inference",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-infer-batch', type=int, default=100,
+                    help='number of batches to run inference on')
+parser.add_argument('--load-epoch', type=int, default=0,
+                    help='load the model params saved at the given training epoch.')
+parser.add_argument('--batch-size', type=int, default=100,
+                    help='number of examples per batch')
+parser.add_argument('--benchmark', action='store_true', default=False,
+                    help='run the script in benchmark mode; if not set, run the accuracy test.')
+parser.add_argument('--verbose', action='store_true', default=False,
+                    help='accuracy for each batch will be logged if set')
+parser.add_argument('--gpu', action='store_true', default=False,
+                    help='Inference on GPU with CUDA')
+parser.add_argument('--model-prefix', type=str, default='checkpoint',
+                    help='the model prefix')
+
+if __name__ == '__main__':
+    import logging
+    head = '%(asctime)-15s %(message)s'
+    logging.basicConfig(level=logging.INFO, format=head)
+
+    # arg parser
+    args = parser.parse_args()
+    logging.info(args)
+    num_iters = args.num_infer_batch
+    batch_size = args.batch_size
+    benchmark = args.benchmark
+    verbose = args.verbose
+    model_prefix = args.model_prefix
+    load_epoch = args.load_epoch
+    ctx = mx.gpu(0) if args.gpu else mx.cpu()
+    # dataset
+    data_dir = os.path.join(os.getcwd(), 'data')
+    val_data = os.path.join(data_dir, ADULT['test'])
+    val_csr, val_dns, val_label = get_uci_adult(data_dir, ADULT['test'], ADULT['url'])
+    # load parameters and symbol
+    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, load_epoch)
+    # data iterator
+    eval_data = mx.io.NDArrayIter({'csr_data': val_csr, 'dns_data': val_dns},
+                                  {'softmax_label': val_label}, batch_size,
+                                  shuffle=True, last_batch_handle='discard')
+    # module
+    mod = mx.mod.Module(symbol=sym, context=ctx, data_names=['csr_data', 'dns_data'],
+                        label_names=['softmax_label'])
+    mod.bind(data_shapes=eval_data.provide_data, label_shapes=eval_data.provide_label)
+    # load the trained parameters (including the sparse weights)
+    mod.set_params(arg_params=arg_params, aux_params=aux_params)
+
+    data_iter = iter(eval_data)
+    nbatch = 0
+    if benchmark:
+        logging.info('Inference benchmark started ...')
+        tic = time.time()
+        for i in range(num_iters):
+            try:
+                batch = data_iter.next()
+            except StopIteration:
+                data_iter.reset()
+            else:
+                mod.forward(batch, is_train=False)
+                for output in mod.get_outputs():
+                    output.wait_to_read()
+                nbatch += 1
+        score = (nbatch*batch_size)/(time.time() - tic)
+        logging.info('batch size %d, processed %.2f samples/s' % (batch_size, score))
+    else:
+        logging.info('Inference started ...')
+        # use accuracy as the metric
+        metric = mx.metric.create(['acc'])
+        accuracy_avg = 0.0
+        for batch in data_iter:
+            nbatch += 1
+            metric.reset()
+            mod.forward(batch, is_train=False)
+            mod.update_metric(metric, batch.label)
+            accuracy_avg += metric.get()[1][0]
+            if verbose:
+                logging.info('batch %d, accuracy = %s' % (nbatch, metric.get()))
+        logging.info('averaged accuracy on eval set is %.5f' % (accuracy_avg/nbatch))
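The script depends on a prior training run: mx.model.load_checkpoint(model_prefix, load_epoch) expects the files written by Module.save_checkpoint, i.e. '<prefix>-symbol.json' and '<prefix>-%04d.params'. A minimal sketch with the defaults above (assuming train.py saves under the same 'checkpoint' prefix):

    import mxnet as mx
    # looks for checkpoint-symbol.json and checkpoint-0000.params in the
    # working directory; run train.py first to produce them
    sym, arg_params, aux_params = mx.model.load_checkpoint('checkpoint', 0)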
diff --git a/example/sparse/wide_deep/train.py b/example/sparse/wide_deep/train.py
index 6fd81b7..eea7030 100644
--- a/example/sparse/wide_deep/train.py
+++ b/example/sparse/wide_deep/train.py
@@ -17,6 +17,7 @@
 
 import mxnet as mx
 from mxnet.test_utils import *
+from config import *
 from data import get_uci_adult
 from model import wide_deep_model
 import argparse
@@ -31,7 +32,7 @@ parser.add_argument('--batch-size', type=int, default=100,
                     help='number of examples per batch')
 parser.add_argument('--lr', type=float, default=0.001,
                     help='learning rate')
-parser.add_argument('--cuda', action='store_true', default=False,
+parser.add_argument('--gpu', action='store_true', default=False,
                     help='Train on GPU with CUDA')
 parser.add_argument('--optimizer', type=str, default='adam',
                     help='what optimizer to use',
@@ -40,19 +41,6 @@ parser.add_argument('--log-interval', type=int, default=100,
                     help='number of batches to wait before logging training status')
 
 
-# Related to feature engineering, please see preprocess in data.py
-ADULT = {
-    'train': 'adult.data',
-    'test': 'adult.test',
-    'url': 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult/',
-    'num_linear_features': 3000,
-    'num_embed_features': 2,
-    'num_cont_features': 38,
-    'embed_input_dims': [1000, 1000],
-    'hidden_units': [8, 50, 100],
-}
-
-
 if __name__ == '__main__':
     import logging
     head = '%(asctime)-15s %(message)s'
@@ -66,7 +54,7 @@ if __name__ == '__main__':
     optimizer = args.optimizer
     log_interval = args.log_interval
     lr = args.lr
-    ctx = mx.gpu(0) if args.cuda else mx.cpu()
+    ctx = mx.gpu(0) if args.gpu else mx.cpu()
 
     # dataset    
     data_dir = os.path.join(os.getcwd(), 'data')
@@ -88,7 +76,7 @@ if __name__ == '__main__':
                                   shuffle=True, last_batch_handle='discard')
     
     # module
-    mod = mx.mod.Module(symbol=model, context=ctx ,data_names=['csr_data', 'dns_data'],
+    mod = mx.mod.Module(symbol=model, context=ctx, data_names=['csr_data', 'dns_data'],
                         label_names=['softmax_label'])
     mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label)
     mod.init_params()
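Renaming --cuda to --gpu in train.py aligns it with the flag added in inference.py, so the two scripts now share the same device switch:

    python train.py --gpu        # train on mx.gpu(0) instead of mx.cpu()
    python inference.py --gpu    # run inference on the same device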