You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/25 18:58:46 UTC
[incubator-mxnet] branch master updated: Fix example example/reinforcement-learning/a3c (#9046)
This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 52219e8 Fix example example/reinforcement-learning/a3c (#9046)
52219e8 is described below
commit 52219e8b52761e142fd64242a785d7a956aebf37
Author: mbaijal <30...@users.noreply.github.com>
AuthorDate: Thu Jan 25 10:58:42 2018 -0800
Fix example example/reinforcement-learning/a3c (#9046)
* Fix a3c to make compatible with python3
* make reload compatible with python2
* Update README
* Adding my name to contributors.md
* Update queue to make it compatible with both py2 and py3
* some minor changes
---
CONTRIBUTORS.md | 1 +
example/reinforcement-learning/a3c/README.md | 8 +++++++-
example/reinforcement-learning/a3c/a3c.py | 7 ++++++-
example/reinforcement-learning/a3c/rl_data.py | 15 ++++++++++-----
4 files changed, 24 insertions(+), 7 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index a1a5a5b..9c68dc2 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -153,3 +153,4 @@ List of Contributors
* [Marco de Abreu](https://github.com/marcoabreu)
- Marco is the creator of the current MXNet CI.
* [Julian Salazar](https://github.com/JulianSlzr)
+* [Meghna Baijal](https://github.com/mbaijal)
diff --git a/example/reinforcement-learning/a3c/README.md b/example/reinforcement-learning/a3c/README.md
index 63faf92..5eaba66 100644
--- a/example/reinforcement-learning/a3c/README.md
+++ b/example/reinforcement-learning/a3c/README.md
@@ -7,8 +7,14 @@ The algorithm should be mostly correct. However I cannot reproduce the result in
Note this is a generalization of the original algorithm since we use `batch_size` threads for each worker instead of the original 1 thread.
+## Prerequisites
+ - Install OpenAI Gym: `pip install gym`
+ - Install the Atari Env: `pip install gym[atari]`
+ - You may need to install flask: `pip install flask`
+ - You may have to install cv2: `pip install opencv-python`
+
## Usage
run `python a3c.py --batch-size=32 --gpus=0` to run training on gpu 0 with batch-size=32.
run `python launcher.py --gpus=0,1 -n 2 python a3c.py` to launch training on 2 gpus (0 and 1), each gpu has two workers.
-
+Note: You might have to update the path to dmlc-core in launcher.py.
diff --git a/example/reinforcement-learning/a3c/a3c.py b/example/reinforcement-learning/a3c/a3c.py
index 4d89a24..f74ce77 100644
--- a/example/reinforcement-learning/a3c/a3c.py
+++ b/example/reinforcement-learning/a3c/a3c.py
@@ -26,6 +26,11 @@ import os
import gym
from datetime import datetime
import time
+import sys
+try:
+ from importlib import reload
+except ImportError:
+ pass
parser = argparse.ArgumentParser(description='Traing A3C with OpenAI Gym')
parser.add_argument('--test', action='store_true', help='run testing', default=False)
@@ -139,7 +144,7 @@ def train():
module.save_params('%s-%04d.params'%(save_model_prefix, epoch))
- for _ in range(epoch_size/args.t_max):
+ for _ in range(int(epoch_size/args.t_max)):
tic = time.time()
# clear gradients
for exe in module._exec_group.grad_arrays:
diff --git a/example/reinforcement-learning/a3c/rl_data.py b/example/reinforcement-learning/a3c/rl_data.py
index ad78975..70f2853 100644
--- a/example/reinforcement-learning/a3c/rl_data.py
+++ b/example/reinforcement-learning/a3c/rl_data.py
@@ -21,13 +21,18 @@ import numpy as np
import gym
import cv2
import math
-import Queue
from threading import Thread
import time
import multiprocessing
import multiprocessing.pool
from flask import Flask, render_template, Response
import signal
+import sys
+is_py3 = sys.version[0] == '3'
+if is_py3:
+ import queue as queue
+else:
+ import Queue as queue
def make_web(queue):
app = Flask(__name__)
@@ -62,7 +67,7 @@ def visual(X, show=True):
buf = np.zeros((h*n, w*n, X.shape[3]), dtype=np.uint8)
for i in range(N):
x = i%n
- y = i/n
+ y = i//n
buf[h*y:h*(y+1), w*x:w*(x+1), :] = X[i]
if show:
cv2.imshow('a', buf)
@@ -88,7 +93,7 @@ class RLDataIter(object):
self.web_viz = web_viz
if web_viz:
- self.queue = Queue.Queue()
+ self.queue = queue.Queue()
self.thread = Thread(target=make_web, args=(self.queue,))
self.thread.daemon = True
self.thread.start()
@@ -117,7 +122,7 @@ class RLDataIter(object):
reward = np.asarray([i[1] for i in new], dtype=np.float32)
done = np.asarray([i[2] for i in new], dtype=np.float32)
- channels = self.state_.shape[1]/self.input_length
+ channels = self.state_.shape[1]//self.input_length
state = np.zeros_like(self.state_)
state[:,:-channels,:,:] = self.state_[:,channels:,:,:]
for i, (ob, env) in enumerate(zip(new, self.env)):
@@ -151,7 +156,7 @@ class GymDataIter(RLDataIter):
return gym.make(self.game)
def visual(self):
- data = self.state_[:4, -self.state_.shape[1]/self.input_length:, :, :]
+ data = self.state_[:4, -self.state_.shape[1]//self.input_length:, :, :]
return visual(np.asarray(data, dtype=np.uint8), False)
if __name__ == '__main__':
--
To stop receiving notification emails like this one, please contact jxie@apache.org.