You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ha...@apache.org on 2018/10/08 20:06:16 UTC
[incubator-mxnet] branch master updated: fix benchmark on control flow operators. (#12693)
This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 077253d fix benchmark on control flow operators. (#12693)
077253d is described below
commit 077253d5ca2dca667bc3c2c1d3cfc04d78c80f29
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Mon Oct 8 13:06:03 2018 -0700
fix benchmark on control flow operators. (#12693)
---
benchmark/python/control_flow/rnn.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/benchmark/python/control_flow/rnn.py b/benchmark/python/control_flow/rnn.py
index 8a44a9c..0849872 100644
--- a/benchmark/python/control_flow/rnn.py
+++ b/benchmark/python/control_flow/rnn.py
@@ -32,6 +32,7 @@ _parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop
_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True)
_parser.add_argument('--warmup_rounds', type=int, default=20)
_parser.add_argument('--test_rounds', type=int, default=100)
+_parser.add_argument('--gpu', type=bool, default=False)
args = _parser.parse_args()
@@ -66,8 +67,7 @@ class WhileRNN(gluon.HybridBlock):
loop_vars=states,
max_iterations=self.length,
)
- assert len(out) == 1
- return out[0]
+ return out
def _zeros(shape, ctx):
@@ -124,7 +124,9 @@ def main():
cell_types = [gluon.rnn.RNNCell,
gluon.rnn.GRUCell,
gluon.rnn.LSTMCell]
- ctxs = [mx.cpu(0)] + [mx.gpu(i) for i in _get_gpus()]
+ ctxs = [mx.cpu(0)]
+ if args.gpu:
+ ctxs = ctxs + [mx.gpu(i) for i in _get_gpus()]
seq_lens = [100]
batch_sizes = [1, 32]
hidden_dims = [512]