piiswrong closed pull request #11051: Fix DQN example
URL: https://github.com/apache/incubator-mxnet/pull/11051
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/example/reinforcement-learning/dqn/README.md b/example/reinforcement-learning/dqn/README.md
index fd32667a1f8..4547904b595 100644
Binary files a/example/reinforcement-learning/dqn/README.md and b/example/reinforcement-learning/dqn/README.md differ
diff --git a/example/reinforcement-learning/dqn/dqn_run_test.py b/example/reinforcement-learning/dqn/dqn_run_test.py
old mode 100644
new mode 100755
index 2abf273978f..e8f36b97976
--- a/example/reinforcement-learning/dqn/dqn_run_test.py
+++ b/example/reinforcement-learning/dqn/dqn_run_test.py
@@ -1,3 +1,5 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -89,8 +91,8 @@ def calculate_avg_reward(game, qnet, test_steps=125000, exploartion=0.05):
                     current_state = game.current_state()
                     state = nd.array(current_state.reshape((1,) + current_state.shape),
                                      ctx=qnet.ctx) / float(255.0)
-                    action = nd.argmax_channel(
-                        qnet.forward(is_train=False, data=state)[0]).asscalar()
+                    action = int(nd.argmax_channel(
+                        qnet.forward(is_train=False, data=state)[0]).asscalar())
             else:
                 action = npy_rng.randint(action_num)
 
@@ -120,7 +122,7 @@ def main():
                         help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`')
     parser.add_argument('-e', '--epoch-range', required=False, type=str, default='22',
                         help='Epochs to run testing. E.g `-e 0,80`, `-e 0,80,2`')
-    parser.add_argument('-v', '--visualization', required=False, type=int, default=0,
+    parser.add_argument('-v', '--visualization', action='store_true',
                         help='Visualize the runs.')
     parser.add_argument('--symbol', required=False, type=str, default="nature",
                         help='type of network, nature or nips')
diff --git a/example/reinforcement-learning/dqn/setup.sh b/example/reinforcement-learning/dqn/setup.sh
index 3fcfacbe0a7..3069fef62ec 100755
--- a/example/reinforcement-learning/dqn/setup.sh
+++ b/example/reinforcement-learning/dqn/setup.sh
@@ -22,9 +22,14 @@ set -x
 
 pip install opencv-python
 pip install scipy
+pip install pygame
 
 # Install arcade learning environment
-sudo apt-get install libsdl1.2-dev libsdl-gfx1.2-dev libsdl-image1.2-dev cmake
+if [[ "$OSTYPE" == "linux-gnu" ]]; then
+    sudo apt-get install libsdl1.2-dev libsdl-gfx1.2-dev libsdl-image1.2-dev cmake
+elif [[ "$OSTYPE" == "darwin"* ]]; then
+    brew install sdl sdl_image sdl_mixer sdl_ttf portmidi
+fi
 git clone git@github.com:mgbellemare/Arcade-Learning-Environment.git || true
 pushd .
 cd Arcade-Learning-Environment


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to