You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2017/08/03 05:50:58 UTC

[incubator-mxnet] branch master updated: reduce model zoo test size (#7318)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cffbc2c  reduce model zoo test size (#7318)
cffbc2c is described below

commit cffbc2c5790261a77b217e1f4cc90a2cac7aeb7f
Author: Sheng Zha <sz...@users.noreply.github.com>
AuthorDate: Wed Aug 2 22:50:55 2017 -0700

    reduce model zoo test size (#7318)
    
    * reduce model zoo test size
    
    * add model print back for pretty-print test
---
 tests/python/unittest/test_gluon_model_zoo.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py
index 200037c..db26fd4 100644
--- a/tests/python/unittest/test_gluon_model_zoo.py
+++ b/tests/python/unittest/test_gluon_model_zoo.py
@@ -3,7 +3,10 @@ import mxnet as mx
 from mxnet.gluon import nn
 from mxnet.gluon.model_zoo.custom_layers import HybridConcurrent, Identity
 from mxnet.gluon.model_zoo.vision import get_model
+import sys
 
+def eprint(*args, **kwargs):
+    print(*args, file=sys.stderr, **kwargs)
 
 def test_concurrent():
     model = HybridConcurrent(concat_dim=1)
@@ -43,11 +46,12 @@ def test_models():
     for model_name in all_models:
         test_pretrain = model_name in pretrained_to_test
         model = get_model(model_name, pretrained=test_pretrain)
-        data_shape = (7, 3, 224, 224) if 'inception' not in model_name else (7, 3, 299, 299)
+        data_shape = (2, 3, 224, 224) if 'inception' not in model_name else (2, 3, 299, 299)
+        eprint('testing forward for %s'%model_name)
         print(model)
         if not test_pretrain:
             model.collect_params().initialize()
-        model(mx.nd.random_uniform(shape=data_shape))
+        model(mx.nd.random_uniform(shape=data_shape)).wait_to_read()
 
 
 if __name__ == '__main__':

-- 
To stop receiving notification emails like this one, please contact
commits@mxnet.apache.org.