You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/04/10 21:00:16 UTC

[GitHub] zhreshold commented on a change in pull request #10483: SSD performance optimization and benchmark script

zhreshold commented on a change in pull request #10483: SSD performance optimization and benchmark script
URL: https://github.com/apache/incubator-mxnet/pull/10483#discussion_r180566232
 
 

 ##########
 File path: example/ssd/benchmark_score.py
 ##########
 @@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import os
+import sys
+import argparse
+import importlib
+import mxnet as mx
+import time
+#from dataset.iterator import DetRecordIter
+#from config.config import cfg
+#from evaluate.eval_metric import MApMetric, VOC07MApMetric
+import logging
+from symbol.symbol_factory import get_symbol
+from symbol.symbol_factory import get_symbol_train
+from symbol import symbol_builder
+
+
+parser = argparse.ArgumentParser(description='MxNet SSD benchmark')
+parser.add_argument('--network', '-n', type=str, default='vgg16_reduced')
+parser.add_argument('--batch_size', '-b', type=int, default=0)
+parser.add_argument('--shape', '-w', type=int, default=300)
+parser.add_argument('--class_num', '-class', type=int, default=20)
+
+
+def get_data_shapes(batch_size):
+    image_shape = (3, 300, 300)
+    return [('data', (batch_size,)+image_shape)]
+
+def get_data(batch_size):
+    data_shapes = get_data_shapes(batch_size)
+    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=mx.cpu()) for _, shape in data_shapes]
+    batch = mx.io.DataBatch(data, [])
+    return batch
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    network = args.network
+    image_shape = args.shape
+    num_classes = args.class_num
+    b = args.batch_size
+    supported_image_shapes = [300, 512]
+    supported_networks = ['vgg16_reduced', 'inceptionv3', 'resnet50']
+
+    if network not in supported_networks:
+        raise Exception(network + " is not supported")
+
+    if image_shape not in supported_image_shapes:
+       raise Exception("Image shape should be either 300*300 or 512*512!")
+
+    if b == 0:
+        batch_sizes = [1, 2, 4, 8, 16, 32]
+    else:
+        batch_sizes = [b]
+
+    data_shape = (3, image_shape, image_shape)
+    net = get_symbol(network, data_shape[1], num_classes=num_classes,
+                     nms_thresh=0.4, force_suppress=True)
+    
+    num_batches = 100
+    dry_run = 5   # use 5 iterations to warm up
+    
+    for bs in batch_sizes:
+        batch = get_data(bs)
+        mod = mx.mod.Module(net, label_names=None, context=mx.cpu())
+        mod.bind(for_training = False,
+                 inputs_need_grad = False,
+                 data_shapes = get_data_shapes(bs))
+        mod.init_params(initializer=mx.init.Xavier(magnitude=2.))
 
 Review comment:
   Try loading some pre-trained models to test the `real` performance.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services