Python gym 模块,envs() 实例源码
我们从Python开源项目中,提取了以下22个代码示例,用于说明如何使用gym.envs()。
def make_env(env_id, seed, rank, log_dir):
    """Return a zero-argument factory that builds one configured environment.

    Nothing is created until the returned callable is invoked.

    Args:
        env_id: Id handed to ``gym.make`` (and to ``make_atari`` for Atari).
        seed: Base seed; the environment is seeded with ``seed + rank``.
        rank: Index of this environment; also names its monitor output
            under ``log_dir``.
        log_dir: Directory for ``bench.Monitor`` output, or ``None`` to
            skip monitoring.

    Returns:
        A callable taking no arguments and returning the wrapped environment.
    """
    def _thunk():
        created = gym.make(env_id)
        # Only touch gym.envs.atari when that sub-package is actually present.
        atari = hasattr(gym.envs, 'atari') and isinstance(
            created.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        if atari:
            created = make_atari(env_id)
        created.seed(seed + rank)

        if log_dir is not None:
            monitor_path = os.path.join(log_dir, str(rank))
            created = bench.Monitor(created, monitor_path)

        if atari:
            created = wrap_deepmind(created)

        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        shape = created.observation_space.shape
        if len(shape) == 3 and shape[2] in (1, 3):
            created = WrapPyTorch(created)
        return created

    return _thunk
def test_default_time_limit():
    """Check the limits ``wrappers.TimeLimit`` applies to an env without its own.

    Registers a dummy VNC environment, wraps it in ``wrappers.TimeLimit`` and
    asserts that the seconds limit equals the module default while no step
    limit is set.
    """
    # We need an env without a default limit
    register(
        id='test.NoLimitDummyVNCEnv-v0',
        entry_point='universe.envs:DummyVNCEnv',
        tags={
            'vnc': True,
        },
    )

    env = gym.make('test.NoLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == wrappers.time_limit.DEFAULT_MAX_EPISODE_SECONDS
    # Identity check: None is a singleton, and `== None` can be fooled by
    # objects that override __eq__.
    assert env._max_episode_steps is None
def add_new_rollouts(spec_ids, overwrite):
    """Update the rollout file with rollouts for registered environment specs.

    Args:
        spec_ids: Collection of spec ids to restrict the run to; when empty
            or falsy, every registered spec with an entry point is used.
        overwrite: When truthy, specs already present in the rollout file
            are processed again instead of being skipped.

    Raises:
        AssertionError: If ``spec_ids`` is given and the number of matching
            specs differs from ``len(spec_ids)``.
    """
    candidates = [s for s in envs.registry.all() if s._entry_point is not None]
    if spec_ids:
        candidates = [s for s in candidates if s.id in spec_ids]
        assert len(candidates) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    modified = False
    for candidate in candidates:
        if overwrite or candidate.id not in rollout_dict:
            # The call comes first in the `or` so it runs for every spec,
            # even after an earlier spec already set `modified`.
            modified = update_rollout_dict(candidate, rollout_dict) or modified
            continue
        logger.debug("Rollout already exists for {}. Skipping.".format(candidate.id))

    if not modified:
        logger.info("No modifications needed.")
        return

    logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
def should_skip_env_spec_for_tests(spec):
    """Decide whether the tests for an environment spec should be skipped.

    Args:
        spec: Object exposing ``_entry_point`` (a string) and ``id``.

    Returns:
        True when the spec's tests should be skipped, False otherwise.
    """
    # We skip tests for envs that require dependencies or are otherwise
    # troublesome to run frequently
    entry_point = spec._entry_point

    # Skip mujoco tests for pull request CI
    mujoco_available = (
        os.environ.get('MUJOCO_KEY_BUNDLE')
        or os.path.exists(os.path.expanduser('~/.mujoco'))
    )
    if not mujoco_available and entry_point.startswith('gym.envs.mujoco:'):
        return True

    skipped_prefixes = (
        'gym.envs.box2d:',
        'gym.envs.parameter_tuning:',
        'gym.envs.safety:Semisuper',
    )
    if (
        'GoEnv' in entry_point
        or 'HexEnv' in entry_point
        or entry_point.startswith(skipped_prefixes)
        # Of the Atari envs, only Pong and Seaquest ids are kept.
        or (entry_point.startswith("gym.envs.atari")
            and not spec.id.startswith(("Pong", "Seaquest")))
    ):
        logger.warning("Skipping tests for env {}".format(entry_point))
        return True
    return False
def add_new_rollouts(spec_ids, overwrite):
    """Update the rollout file with rollouts for registered environment specs.

    Args:
        spec_ids: Collection of spec ids to restrict the run to; when empty
            or falsy, every registered spec with an entry point is used.
        overwrite: When truthy, specs already present in the rollout file
            are processed again instead of being skipped.

    Raises:
        AssertionError: If ``spec_ids`` is given and the number of matching
            specs differs from ``len(spec_ids)``.
    """
    selected = [s for s in envs.registry.all() if s._entry_point is not None]
    if spec_ids:
        selected = [s for s in selected if s.id in spec_ids]
        assert len(selected) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    modified = False
    for env_spec in selected:
        already_recorded = not overwrite and env_spec.id in rollout_dict
        if already_recorded:
            logger.debug("Rollout already exists for {}. Skipping.".format(env_spec.id))
            continue
        # The call comes first in the `or` so it runs for every spec,
        # even after an earlier spec already set `modified`.
        modified = update_rollout_dict(env_spec, rollout_dict) or modified

    if not modified:
        logger.info("No modifications needed.")
        return

    logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
def add_new_rollouts(spec_ids, overwrite):
    """Update the rollout file with rollouts for registered environment specs.

    Args:
        spec_ids: Collection of spec ids to restrict the run to; when empty
            or falsy, every registered spec with an entry point is used.
        overwrite: When truthy, specs already present in the rollout file
            are processed again instead of being skipped.

    Raises:
        AssertionError: If ``spec_ids`` is given and the number of matching
            specs differs from ``len(spec_ids)``.
    """
    targets = [item for item in envs.registry.all() if item._entry_point is not None]
    if spec_ids:
        targets = [item for item in targets if item.id in spec_ids]
        assert len(targets) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    modified = False
    for item in targets:
        if overwrite or item.id not in rollout_dict:
            # Keep the call on the left of `or` so it is never short-circuited.
            modified = update_rollout_dict(item, rollout_dict) or modified
        else:
            logger.debug("Rollout already exists for {}. Skipping.".format(item.id))

    if not modified:
        logger.info("No modifications needed.")
        return

    logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True):
    """Create the Gym environment ``env_name`` and optionally start its monitor.

    Args:
        env_name: Id passed to ``gym.envs.make``.
        record_video: When falsy, a ``NoVideoSchedule`` is used.
        video_schedule: Schedule handed to the monitor; defaults to
            ``CappedCubicVideoSchedule()`` when ``None`` and video is on.
        log_dir: Monitor output directory. When ``None`` it becomes
            ``<snapshot_dir>/gym_log`` if a snapshot dir is configured.
        record_log: When ``False``, monitoring is disabled.

    Raises:
        AssertionError: If ``record_video`` is truthy while ``record_log``
            is falsy.
    """
    # log_dir is resolved before locals() is captured, so quick_init sees
    # the resolved value.
    if log_dir is None:
        if logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        else:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
    Serializable.quick_init(self, locals())

    env = gym.envs.make(env_name)
    self.env = env
    self.env_id = env.spec.id

    monitor.logger.setLevel(logging.WARNING)

    # Video recording requires logging to be enabled as well.
    assert record_log or not record_video

    monitoring_enabled = log_dir is not None and record_log is not False
    if monitoring_enabled:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        # add 'force=True' if want overwrite dirs
        self.env.monitor.start(log_dir, video_schedule, force=True)
    self.monitoring = monitoring_enabled

    self._observation_space = convert_gym_space(env.observation_space)
    self._action_space = convert_gym_space(env.action_space)
    self._horizon = env.spec.timestep_limit
    self._log_dir = log_dir
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
             force_reset=False):
    """Create the Gym environment ``env_name``, optionally wrapped in a Monitor.

    Args:
        env_name: Id passed to ``gym.envs.make``.
        record_video: When falsy, a ``NoVideoSchedule`` is used.
        video_schedule: Passed to the Monitor as ``video_callable``; defaults
            to ``CappedCubicVideoSchedule()`` when ``None`` and video is on.
        log_dir: Monitor output directory. When ``None`` it becomes
            ``<snapshot_dir>/gym_log`` if a snapshot dir is configured.
        record_log: When ``False``, monitoring is disabled.
        force_reset: Stored on the instance as ``_force_reset``.

    Raises:
        AssertionError: If ``record_video`` is truthy while ``record_log``
            is falsy.
    """
    # log_dir is resolved before locals() is captured, so quick_init sees
    # the resolved value.
    if log_dir is None:
        if logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        else:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
    Serializable.quick_init(self, locals())

    env = gym.envs.make(env_name)
    self.env = env
    self.env_id = env.spec.id

    # Video recording requires logging to be enabled as well.
    assert record_log or not record_video

    monitoring_enabled = log_dir is not None and record_log is not False
    if monitoring_enabled:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
    self.monitoring = monitoring_enabled

    # Spaces and horizon are read from the local `env` returned by
    # gym.envs.make, not from the (possibly Monitor-wrapped) self.env.
    self._observation_space = convert_gym_space(env.observation_space)
    logger.log("observation space: {}".format(self._observation_space))
    self._action_space = convert_gym_space(env.action_space)
    logger.log("action space: {}".format(self._action_space))
    self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
    self._log_dir = log_dir
    self._force_reset = force_reset
def test_env_semantics(spec):
    """Compare a freshly generated rollout hash for ``spec`` with the stored one.

    Returns early when ``spec.id`` has no entry in the rollout file (warning
    first unless the spec is flagged nondeterministic).

    Raises:
        ValueError: Carrying the list of mismatch messages when any of the
            observations, actions, rewards or dones values differ.
    """
    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    if spec.id not in rollout_dict:
        if not spec.nondeterministic:
            logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
        return

    logger.info("Testing rollout for {} environment...".format(spec.id))

    observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

    expected = rollout_dict[spec.id]
    # (stored key, freshly computed value, message template)
    comparisons = (
        ('observations', observations_now, 'Observations not equal for {} -- expected {} but got {}'),
        ('actions', actions_now, 'Actions not equal for {} -- expected {} but got {}'),
        ('rewards', rewards_now, 'Rewards not equal for {} -- expected {} but got {}'),
        ('dones', dones_now, 'Dones not equal for {} -- expected {} but got {}'),
    )
    errors = [
        template.format(spec.id, expected[key], current)
        for key, current, template in comparisons
        if expected[key] != current
    ]

    if errors:
        for message in errors:
            logger.warn(message)
        raise ValueError(errors)
def __init__(self):
    """Start with empty ordered mappings for groups, envs and benchmarks."""
    for attribute in ('groups', 'envs', 'benchmarks'):
        # Each attribute gets its own independent OrderedDict.
        setattr(self, attribute, collections.OrderedDict())
def add_group(self, id, name, description, universe=False):
    """Register (or replace) the group ``id`` with an empty env list.

    Args:
        id: Key under which the group is stored in ``self.groups``.
        name: Stored under the ``'name'`` key.
        description: Stored under the ``'description'`` key.
        universe: Stored under the ``'universe'`` key.
    """
    group_record = dict(
        id=id,
        name=name,
        description=description,
        envs=[],
        universe=universe,
    )
    self.groups[id] = group_record
def add_task(self, id, group, summary=None, description=None, background=None, deprecated=False, experimental=False, contributor=None):
    """Register the task ``id`` and, unless deprecated, list it in its group.

    Args:
        id: Key under which the task is stored in ``self.envs``.
        group: Id of the group the task belongs to.
        summary, description, background, contributor: Stored verbatim in
            the task record.
        deprecated: When truthy, the task is recorded but not appended to
            the group's ``'envs'`` list (and the group is not looked up).
        experimental: Stored verbatim in the task record.

    Raises:
        KeyError: If the task is not deprecated and ``group`` is not a
            registered group. Nothing is modified in that case.
    """
    # Resolve the group's env list first so that an unknown group raises
    # KeyError before self.envs is touched (no half-registered task).
    group_envs = None if deprecated else self.groups[group]['envs']

    self.envs[id] = {
        'group': group,
        'id': id,
        'summary': summary,
        'description': description,
        'background': background,
        'deprecated': deprecated,
        'experimental': experimental,
        'contributor': contributor,
    }
    if group_envs is not None:
        group_envs.append(id)
def test_env_semantics(spec):
    """Compare a freshly generated rollout hash for ``spec`` with the stored one.

    Returns early when ``spec.id`` has no entry in the rollout file (warning
    first unless the spec is flagged nondeterministic).

    Raises:
        ValueError: Carrying the list of mismatch messages when any of the
            observations, actions, rewards or dones values differ.
    """
    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    if spec.id not in rollout_dict:
        if not spec.nondeterministic:
            logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
        return

    logger.info("Testing rollout for {} environment...".format(spec.id))

    observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

    stored = rollout_dict[spec.id]
    # (stored key, freshly computed value, message template)
    checks = (
        ('observations', observations_now, 'Observations not equal for {} -- expected {} but got {}'),
        ('actions', actions_now, 'Actions not equal for {} -- expected {} but got {}'),
        ('rewards', rewards_now, 'Rewards not equal for {} -- expected {} but got {}'),
        ('dones', dones_now, 'Dones not equal for {} -- expected {} but got {}'),
    )
    errors = []
    for key, fresh, template in checks:
        if stored[key] != fresh:
            errors.append(template.format(spec.id, stored[key], fresh))

    if errors:
        for message in errors:
            logger.warn(message)
        raise ValueError(errors)
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
             force_reset=False):
    """Create the Gym environment ``env_name``, optionally wrapped in a Monitor.

    Args:
        env_name: Id passed to ``gym.envs.make``.
        record_video: When falsy, a ``NoVideoSchedule`` is used.
        video_schedule: Passed to the Monitor as ``video_callable``; defaults
            to ``CappedCubicVideoSchedule()`` when ``None`` and video is on.
        log_dir: Monitor output directory. When ``None`` it becomes
            ``<snapshot_dir>/gym_log`` if a snapshot dir is configured.
        record_log: When ``False``, monitoring is disabled.
        force_reset: Stored on the instance as ``_force_reset``.

    Raises:
        AssertionError: If ``record_video`` is truthy while ``record_log``
            is falsy.
    """
    # log_dir is resolved before locals() is captured, so quick_init sees
    # the resolved value.
    if log_dir is None:
        if logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        else:
            logger.log(
                "Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
    Serializable.quick_init(self, locals())

    env = gym.envs.make(env_name)
    self.env = env
    self.env_id = env.spec.id

    # Video recording requires logging to be enabled as well.
    assert record_log or not record_video

    should_monitor = not (log_dir is None or record_log is False)
    if should_monitor:
        if record_video:
            if video_schedule is None:
                video_schedule = CappedCubicVideoSchedule()
        else:
            video_schedule = NoVideoSchedule()
        self.env = gym.wrappers.Monitor(
            self.env, log_dir, video_callable=video_schedule, force=True)
    self.monitoring = should_monitor

    # Spaces and horizon are read from the local `env` returned by
    # gym.envs.make, not from the (possibly Monitor-wrapped) self.env.
    self._observation_space = convert_gym_space(env.observation_space)
    logger.log("observation space: {}".format(self._observation_space))
    self._action_space = convert_gym_space(env.action_space)
    logger.log("action space: {}".format(self._action_space))
    self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
    self._log_dir = log_dir
    self._force_reset = force_reset
def wrap_dqn(env, history_len=4, action_repeat=4, no_op_max=30):
    """
    Apply the DQN preprocessing wrapper chain to an environment.

    Wrappers are applied in this order: EpisodicLifeEnv, NoOpResetEnv,
    MaxAndSkipEnv, FireResetEnv (only when the unwrapped env reports a
    'FIRE' action), ProcessFrame84, FrameStack, ClippedRewardsWrapper and
    ScaledFloatFrame.

    Parameters
    ----------
    env: gym.envs
        Environment to wrap.
    history_len: int
        Passed to FrameStack (presumably the number of stacked frames).
    action_repeat: int
        Passed to MaxAndSkipEnv (presumably how many steps each action
        is repeated for).
    no_op_max: int
        Passed to NoOpResetEnv (presumably the maximum number of no-op
        steps taken on reset).

    Returns
    ----------
    env
        The wrapped environment; the outermost wrapper is ScaledFloatFrame.
    """
    wrapped = MaxAndSkipEnv(
        NoOpResetEnv(EpisodicLifeEnv(env), no_op_max),
        action_repeat,
    )
    if 'FIRE' in wrapped.unwrapped.get_action_meanings():
        wrapped = FireResetEnv(wrapped)
    wrapped = FrameStack(ProcessFrame84(wrapped), history_len)
    return ScaledFloatFrame(ClippedRewardsWrapper(wrapped))
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
             force_reset=False):
    """Create the Gym environment ``env_name``, optionally wrapped in a Monitor.

    Args:
        env_name: Id passed to ``gym.envs.make``.
        record_video: When falsy, a ``NoVideoSchedule`` is used.
        video_schedule: Passed to the Monitor as ``video_callable``; defaults
            to ``CappedCubicVideoSchedule()`` when ``None`` and video is on.
        log_dir: Monitor output directory. When ``None`` it becomes
            ``<snapshot_dir>/gym_log`` if a snapshot dir is configured.
        record_log: When ``False``, monitoring is disabled.
        force_reset: Stored on the instance as ``_force_reset``.

    Raises:
        AssertionError: If ``record_video`` is truthy while ``record_log``
            is falsy.
    """
    # log_dir is resolved before locals() is captured, so quick_init sees
    # the resolved value.
    if log_dir is None:
        if logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        else:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
    Serializable.quick_init(self, locals())

    env = gym.envs.make(env_name)
    self.env = env
    self.env_id = env.spec.id

    # Video recording requires logging to be enabled as well.
    assert record_log or not record_video

    if log_dir is not None and record_log is not False:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
        self.monitoring = True
    else:
        self.monitoring = False

    # Spaces and horizon are read from the local `env` returned by
    # gym.envs.make, not from the (possibly Monitor-wrapped) self.env.
    observation_space = convert_gym_space(env.observation_space)
    self._observation_space = observation_space
    logger.log("observation space: {}".format(self._observation_space))
    action_space = convert_gym_space(env.action_space)
    self._action_space = action_space
    logger.log("action space: {}".format(self._action_space))
    self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
    self._log_dir = log_dir
    self._force_reset = force_reset
def __init__(self):
    """Start with empty ordered mappings for groups, envs and benchmarks."""
    new_mapping = collections.OrderedDict
    # Three separate instances, one per attribute.
    self.groups, self.envs, self.benchmarks = new_mapping(), new_mapping(), new_mapping()
def add_group(self, id, name, description):
    """Register (or replace) the group ``id`` with an empty env list.

    Args:
        id: Key under which the group is stored in ``self.groups``.
        name: Stored under the ``'name'`` key.
        description: Stored under the ``'description'`` key.
    """
    group_record = dict(id=id, name=name, description=description, envs=[])
    self.groups[id] = group_record
def add_task(self, id, group, summary=None, description=None, background=None, deprecated=False, experimental=False, contributor=None):
    """Register the task ``id`` and, unless deprecated, list it in its group.

    Args:
        id: Key under which the task is stored in ``self.envs``.
        group: Id of the group the task belongs to.
        summary, description, background, contributor: Stored verbatim in
            the task record.
        deprecated: When truthy, the task is recorded but not appended to
            the group's ``'envs'`` list (and the group is not looked up).
        experimental: Stored verbatim in the task record.

    Raises:
        KeyError: If the task is not deprecated and ``group`` is not a
            registered group. Nothing is modified in that case.
    """
    # Resolve the group's env list first so that an unknown group raises
    # KeyError before self.envs is touched (no half-registered task).
    group_envs = None if deprecated else self.groups[group]['envs']

    self.envs[id] = {
        'group': group,
        'id': id,
        'summary': summary,
        'description': description,
        'background': background,
        'deprecated': deprecated,
        'experimental': experimental,
        'contributor': contributor,
    }
    if group_envs is not None:
        group_envs.append(id)
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
             force_reset=False):
    """Create the Gym environment ``env_name``, optionally wrapped in a Monitor.

    Args:
        env_name: Id passed to ``gym.envs.make``.
        record_video: When falsy, a ``NoVideoSchedule`` is used.
        video_schedule: Passed to the Monitor as ``video_callable``; defaults
            to ``CappedCubicVideoSchedule()`` when ``None`` and video is on.
        log_dir: Monitor output directory. When ``None`` it becomes
            ``<snapshot_dir>/gym_log`` if a snapshot dir is configured.
        record_log: When ``False``, monitoring is disabled.
        force_reset: Stored on the instance as ``_force_reset``.

    Raises:
        AssertionError: If ``record_video`` is truthy while ``record_log``
            is falsy.
    """
    # log_dir is resolved before locals() is captured, so quick_init sees
    # the resolved value.
    if log_dir is None:
        if logger.get_snapshot_dir() is not None:
            log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        else:
            logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
    Serializable.quick_init(self, locals())

    env = gym.envs.make(env_name)
    self.env = env
    self.env_id = env.spec.id

    monitor_manager.logger.setLevel(logging.WARNING)

    # Video recording requires logging to be enabled as well.
    assert record_log or not record_video

    monitoring_enabled = log_dir is not None and record_log is not False
    if monitoring_enabled:
        if not record_video:
            video_schedule = NoVideoSchedule()
        elif video_schedule is None:
            video_schedule = CappedCubicVideoSchedule()
        self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
    self.monitoring = monitoring_enabled

    # Spaces and horizon are read from the local `env` returned by
    # gym.envs.make, not from the (possibly Monitor-wrapped) self.env.
    self._observation_space = convert_gym_space(env.observation_space)
    self._action_space = convert_gym_space(env.action_space)
    self._horizon = env.spec.timestep_limit
    self._log_dir = log_dir
    self._force_reset = force_reset