Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/25 18:58:46 UTC

[incubator-mxnet] branch master updated: Fix example example/reinforcement-learning/a3c (#9046)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 52219e8  Fix example  example/reinforcement-learning/a3c (#9046)
52219e8 is described below

commit 52219e8b52761e142fd64242a785d7a956aebf37
Author: mbaijal <30...@users.noreply.github.com>
AuthorDate: Thu Jan 25 10:58:42 2018 -0800

    Fix example  example/reinforcement-learning/a3c (#9046)
    
    * Fix a3c to make compatible with python3
    
    * make reload compatible with python2
    
    * Update README
    
    * Adding my name to contributors.md
    
    * Update queue to make it compatible with both py2 and py3
    
    * some minor changes
---
 CONTRIBUTORS.md                               |  1 +
 example/reinforcement-learning/a3c/README.md  |  8 +++++++-
 example/reinforcement-learning/a3c/a3c.py     |  7 ++++++-
 example/reinforcement-learning/a3c/rl_data.py | 15 ++++++++++-----
 4 files changed, 24 insertions(+), 7 deletions(-)
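
For readers skimming the patch: the main Python 3 portability device below is
importing `reload` from `importlib`, with a fallback to the Python 2 builtin.
A minimal standalone sketch of that pattern (the reloaded module is only
illustrative):

    try:
        # Python 3.4+: reload() was moved out of the builtins into importlib
        from importlib import reload
    except ImportError:
        # Python 2: reload() is a builtin, so there is nothing to import
        pass

    import json
    reload(json)  # re-executes the module's code in place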

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index a1a5a5b..9c68dc2 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -153,3 +153,4 @@ List of Contributors
 * [Marco de Abreu](https://github.com/marcoabreu)
  - Marco is the creator of the current MXNet CI.
 * [Julian Salazar](https://github.com/JulianSlzr)
+* [Meghna Baijal](https://github.com/mbaijal)
diff --git a/example/reinforcement-learning/a3c/README.md b/example/reinforcement-learning/a3c/README.md
index 63faf92..5eaba66 100644
--- a/example/reinforcement-learning/a3c/README.md
+++ b/example/reinforcement-learning/a3c/README.md
@@ -7,8 +7,14 @@ The algorithm should be mostly correct. However I cannot reproduce the result in
 
 Note this is a generalization of the original algorithm since we use `batch_size` threads for each worker instead of the original 1 thread.
 
+## Prerequisites
+  - Install OpenAI Gym: `pip install gym`
+  - Install the Atari Env: `pip install gym[atari]`
+  - You may need to install flask: `pip install flask`
+  - You may have to install cv2: `pip install opencv-python`
+
 ## Usage
 run `python a3c.py --batch-size=32 --gpus=0` to run training on gpu 0 with batch-size=32.
 
 run `python launcher.py --gpus=0,1 -n 2 python a3c.py` to launch training on 2 gpus (0 and 1), each gpu has two workers.
-
+Note: You might have to update the path to dmlc-core in launcher.py.
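
A quick way to verify the prerequisites listed in the README above is to import
the optional dependencies and build an Atari environment (the env id is only an
example; any installed Atari game works):

    import cv2    # provided by opencv-python
    import flask
    import gym

    env = gym.make('Breakout-v0')   # requires gym[atari]
    obs = env.reset()
    print(obs.shape, env.action_space)
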
diff --git a/example/reinforcement-learning/a3c/a3c.py b/example/reinforcement-learning/a3c/a3c.py
index 4d89a24..f74ce77 100644
--- a/example/reinforcement-learning/a3c/a3c.py
+++ b/example/reinforcement-learning/a3c/a3c.py
@@ -26,6 +26,11 @@ import os
 import gym
 from datetime import datetime
 import time
+import sys
+try:
+    from importlib import reload
+except ImportError:
+    pass
 
 parser = argparse.ArgumentParser(description='Training A3C with OpenAI Gym')
 parser.add_argument('--test', action='store_true', help='run testing', default=False)
@@ -139,7 +144,7 @@ def train():
             module.save_params('%s-%04d.params'%(save_model_prefix, epoch))
 
 
-        for _ in range(epoch_size/args.t_max):
+        for _ in range(int(epoch_size/args.t_max)):
             tic = time.time()
             # clear gradients
             for exe in module._exec_group.grad_arrays:
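
The `int(...)` wrapper in the hunk above is needed because `/` performs true
division in Python 3 and returns a float, which `range()` rejects. A minimal
illustration (the values are arbitrary):

    epoch_size, t_max = 100, 32

    # Python 3: range(epoch_size / t_max) raises
    # TypeError: 'float' object cannot be interpreted as an integer
    steps = range(int(epoch_size / t_max))  # truncates: range(0, 3)
    steps2 = range(epoch_size // t_max)     # floor division, same result for positive operands

The rl_data.py hunks below apply the same fix with the floor-division operator `//`.
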
diff --git a/example/reinforcement-learning/a3c/rl_data.py b/example/reinforcement-learning/a3c/rl_data.py
index ad78975..70f2853 100644
--- a/example/reinforcement-learning/a3c/rl_data.py
+++ b/example/reinforcement-learning/a3c/rl_data.py
@@ -21,13 +21,18 @@ import numpy as np
 import gym
 import cv2
 import math
-import Queue
 from threading import Thread
 import time
 import multiprocessing
 import multiprocessing.pool
 from flask import Flask, render_template, Response
 import signal
+import sys
+is_py3 = sys.version[0] == '3'
+if is_py3:
+    import queue as queue
+else:
+    import Queue as queue
 
 def make_web(queue):
     app = Flask(__name__)
@@ -62,7 +67,7 @@ def visual(X, show=True):
     buf = np.zeros((h*n, w*n, X.shape[3]), dtype=np.uint8)
     for i in range(N):
         x = i%n
-        y = i/n
+        y = i//n
         buf[h*y:h*(y+1), w*x:w*(x+1), :] = X[i]
     if show:
         cv2.imshow('a', buf)
@@ -88,7 +93,7 @@ class RLDataIter(object):
 
         self.web_viz = web_viz
         if web_viz:
-            self.queue = Queue.Queue()
+            self.queue = queue.Queue()
             self.thread = Thread(target=make_web, args=(self.queue,))
             self.thread.daemon = True
             self.thread.start()
@@ -117,7 +122,7 @@ class RLDataIter(object):
         reward = np.asarray([i[1] for i in new], dtype=np.float32)
         done = np.asarray([i[2] for i in new], dtype=np.float32)
 
-        channels = self.state_.shape[1]/self.input_length
+        channels = self.state_.shape[1]//self.input_length
         state = np.zeros_like(self.state_)
         state[:,:-channels,:,:] = self.state_[:,channels:,:,:]
         for i, (ob, env) in enumerate(zip(new, self.env)):
@@ -151,7 +156,7 @@ class GymDataIter(RLDataIter):
         return gym.make(self.game)
 
     def visual(self):
-        data = self.state_[:4, -self.state_.shape[1]/self.input_length:, :, :]
+        data = self.state_[:4, -self.state_.shape[1]//self.input_length:, :, :]
         return visual(np.asarray(data, dtype=np.uint8), False)
 
 if __name__ == '__main__':
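
The `Queue` -> `queue` rename handled in this file is the other recurring
Python 2/3 difference in the patch. A common alternative to the `sys.version`
check used in the diff, equivalent in effect, is to try the Python 3 name
first:

    try:
        import queue               # Python 3 name
    except ImportError:
        import Queue as queue      # Python 2 name

    q = queue.Queue()
    q.put('frame')
    print(q.get())  # -> 'frame'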

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.