You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original web page) is not preserved in this plain-text version.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/08 20:06:16 UTC

[incubator-mxnet] branch master updated: fix benchmark on control flow operators. (#12693)

This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 077253d  fix benchmark on control flow operators. (#12693)
077253d is described below

commit 077253d5ca2dca667bc3c2c1d3cfc04d78c80f29
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Mon Oct 8 13:06:03 2018 -0700

    fix benchmark on control flow operators. (#12693)
---
 benchmark/python/control_flow/rnn.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py
index 8a44a9c..0849872 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/rnn.py
@@ -32,6 +32,7 @@ _parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop
 _parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True)
 _parser.add_argument('--warmup_rounds', type=int, default=20)
 _parser.add_argument('--test_rounds', type=int, default=100)
+_parser.add_argument('--gpu', type=bool, default=False)
 args = _parser.parse_args()
 
 
@@ -66,8 +67,7 @@ class WhileRNN(gluon.HybridBlock):
             loop_vars=states,
             max_iterations=self.length,
         )
-        assert len(out) == 1
-        return out[0]
+        return out
 
 
 def _zeros(shape, ctx):
@@ -124,7 +124,9 @@ def main():
     cell_types = [gluon.rnn.RNNCell,
                   gluon.rnn.GRUCell,
                   gluon.rnn.LSTMCell]
-    ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()]
+    ctxs = [mx.cpu(0)]
+    if args.gpu:
+        ctxs = ctxs + [mx.gpu(i) for i in _get_gpus()]
     seq_lens = [100]
     batch_sizes = [1, 32]
     hidden_dims = [512]