piiswrong closed pull request #11051: Fix DQN example URL: https://github.com/apache/incubator-mxnet/pull/11051
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/example/reinforcement-learning/dqn/README.md b/example/reinforcement-learning/dqn/README.md index fd32667a1f8..4547904b595 100644 Binary files a/example/reinforcement-learning/dqn/README.md and b/example/reinforcement-learning/dqn/README.md differ diff --git a/example/reinforcement-learning/dqn/dqn_run_test.py b/example/reinforcement-learning/dqn/dqn_run_test.py old mode 100644 new mode 100755 index 2abf273978f..e8f36b97976 --- a/example/reinforcement-learning/dqn/dqn_run_test.py +++ b/example/reinforcement-learning/dqn/dqn_run_test.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -89,8 +91,8 @@ def calculate_avg_reward(game, qnet, test_steps=125000, exploartion=0.05): current_state = game.current_state() state = nd.array(current_state.reshape((1,) + current_state.shape), ctx=qnet.ctx) / float(255.0) - action = nd.argmax_channel( - qnet.forward(is_train=False, data=state)[0]).asscalar() + action = int(nd.argmax_channel( + qnet.forward(is_train=False, data=state)[0]).asscalar()) else: action = npy_rng.randint(action_num) @@ -120,7 +122,7 @@ def main(): help='Running Context. E.g `-c gpu` or `-c gpu1` or `-c cpu`') parser.add_argument('-e', '--epoch-range', required=False, type=str, default='22', help='Epochs to run testing. 
E.g `-e 0,80`, `-e 0,80,2`') - parser.add_argument('-v', '--visualization', required=False, type=int, default=0, + parser.add_argument('-v', '--visualization', action='store_true', help='Visualize the runs.') parser.add_argument('--symbol', required=False, type=str, default="nature", help='type of network, nature or nips') diff --git a/example/reinforcement-learning/dqn/setup.sh b/example/reinforcement-learning/dqn/setup.sh index 3fcfacbe0a7..3069fef62ec 100755 --- a/example/reinforcement-learning/dqn/setup.sh +++ b/example/reinforcement-learning/dqn/setup.sh @@ -22,9 +22,14 @@ set -x pip install opencv-python pip install scipy +pip install pygame # Install arcade learning environment -sudo apt-get install libsdl1.2-dev libsdl-gfx1.2-dev libsdl-image1.2-dev cmake +if [[ "$OSTYPE" == "linux-gnu" ]]; then + sudo apt-get install libsdl1.2-dev libsdl-gfx1.2-dev libsdl-image1.2-dev cmake +elif [[ "$OSTYPE" == "darwin"* ]]; then + brew install sdl sdl_image sdl_mixer sdl_ttf portmidi +fi git clone git@github.com:mgbellemare/Arcade-Learning-Environment.git || true pushd . cd Arcade-Learning-Environment ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services