Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gymnasium support #94

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion backup_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
state = env.reset()
done = False
else:
state, reward, done, info = env.step(env.action_space.sample())
state, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated
if (i + 1) % 12:
env._backup()
if (i + 1) % 27:
Expand Down
14 changes: 10 additions & 4 deletions nes_py/app/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ def _get_args():
choices=['human', 'random'],
help='The execution mode for the emulation.',
)
parser.add_argument('--seed', '-S',
type=int,
help='the random number seed to use'
)
# add the argument for the number of steps to take in random mode
parser.add_argument('--steps', '-s',
type=int,
Expand All @@ -34,13 +38,15 @@ def main():
"""The main entry point for the command line interface."""
# get arguments from the command line
args = _get_args()
# create the environment
env = NESEnv(args.rom)
# play the environment with the given mode
if args.mode == 'human':
play_human(env)
# environment is initialized without a rendering mode, as play_human creates its own
env = NESEnv(args.rom)
play_human(env, seed=args.seed)
else:
play_random(env, args.steps)
# create the environment
env = NESEnv(args.rom, render_mode='human')
play_random(env, args.steps, seed=args.seed)


# explicitly define the outward facing API of this module
Expand Down
17 changes: 11 additions & 6 deletions nes_py/app/play_human.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""A method to play gym environments using human IO inputs."""
import gym
import gymnasium as gym
import time
from pyglet import clock
from .._image_viewer import ImageViewer
Expand All @@ -9,7 +9,7 @@
_NOP = 0


def play_human(env: gym.Env, callback=None):
def play_human(env: gym.Env, callback=None, seed=None):
"""
Play the environment using keyboard as a human.

Expand Down Expand Up @@ -44,7 +44,11 @@ def play_human(env: gym.Env, callback=None):
relevant_keys=set(sum(map(list, keys_to_action.keys()), []))
)
# create a done flag for the environment
done = True
done = False
# reset the environment with the given seed
state, _ = env.reset(seed=seed)
# render the initial state
viewer.show(env.unwrapped.screen)
# prepare frame rate limiting
target_frame_duration = 1 / env.metadata['video.frames_per_second']
last_frame_time = 0
Expand All @@ -62,15 +66,16 @@ def play_human(env: gym.Env, callback=None):
# reset if the environment is done
if done:
done = False
state = env.reset()
state, _ = env.reset()
viewer.show(env.unwrapped.screen)
# unwrap the action based on pressed relevant keys
action = keys_to_action.get(viewer.pressed_keys, _NOP)
next_state, reward, done, _ = env.step(action)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
viewer.show(env.unwrapped.screen)
# pass the observation data through the callback
if callback is not None:
callback(state, action, reward, done, next_state)
callback(state, action, reward, terminated, truncated, next_state)
state = next_state
# shutdown if the escape key is pressed
if viewer.is_escape_pressed:
Expand Down
10 changes: 6 additions & 4 deletions nes_py/app/play_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from tqdm import tqdm


def play_random(env, steps):
def play_random(env, steps, seed=None):
"""
Play the environment making uniformly random decisions.

Expand All @@ -15,13 +15,15 @@ def play_random(env, steps):

"""
try:
done = True
done = False
_, _ = env.reset(seed=seed)
progress = tqdm(range(steps))
for _ in progress:
if done:
_ = env.reset()
_, _ = env.reset()
action = env.action_space.sample()
_, reward, done, info = env.step(action)
_, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
progress.set_postfix(reward=reward, info=info)
env.render()
except KeyboardInterrupt:
Expand Down
Loading