Python gym 模块,envs() 实例源码

我们从Python开源项目中,提取了以下22个代码示例,用于说明如何使用gym.envs()

项目:pytorch-a2c-ppo-acktr    作者:ikostrikov    | 项目源码 | 文件源码
def make_env(env_id, seed, rank, log_dir):
    """Return a zero-argument factory that builds one configured environment.

    The returned callable creates the environment with ``gym.make``; when the
    unwrapped environment is an Atari environment it is rebuilt with
    ``make_atari`` and later wrapped with ``wrap_deepmind``. The environment
    is seeded with ``seed + rank``. If ``log_dir`` is not None it is wrapped
    in ``bench.Monitor`` writing under ``log_dir/<rank>``.
    """
    def _thunk():
        base_env = gym.make(env_id)
        # Guard with hasattr so the isinstance check is only attempted when
        # the atari sub-package is present on gym.envs.
        atari = hasattr(gym.envs, 'atari') and isinstance(
            base_env.unwrapped, gym.envs.atari.atari_env.AtariEnv)
        env = make_atari(env_id) if atari else base_env
        env.seed(seed + rank)
        if log_dir is not None:
            monitor_path = os.path.join(log_dir, str(rank))
            env = bench.Monitor(env, monitor_path)
        if atari:
            env = wrap_deepmind(env)
        # If the input has shape (W,H,3), wrap for PyTorch convolutions
        shape = env.observation_space.shape
        image_like = len(shape) == 3 and shape[2] in [1, 3]
        return WrapPyTorch(env) if image_like else env

    return _thunk
项目:universe    作者:openai    | 项目源码 | 文件源码
def test_default_time_limit():
    """Check that ``wrappers.TimeLimit`` falls back to the default limit.

    Registers a dummy VNC environment, wraps it in ``wrappers.TimeLimit``
    without explicit limits, and asserts that the seconds limit equals
    ``DEFAULT_MAX_EPISODE_SECONDS`` while the step limit is unset.
    """
    # We need an env without a default limit
    register(
        id='test.NoLimitDummyVNCEnv-v0',
        entry_point='universe.envs:DummyVNCEnv',
        tags={
            'vnc': True,
            },
    )

    env = gym.make('test.NoLimitDummyVNCEnv-v0')
    env.configure(_n=1)
    env = wrappers.TimeLimit(env)
    env.reset()

    assert env._max_episode_seconds == wrappers.time_limit.DEFAULT_MAX_EPISODE_SECONDS
    # Identity comparison is the correct way to test for None.
    assert env._max_episode_steps is None
项目:gym    作者:openai    | 项目源码 | 文件源码
def add_new_rollouts(spec_ids, overwrite):
  """Add rollouts for registered environment specs to ROLLOUT_FILE.

  Only specs whose ``_entry_point`` is not None are considered. When
  ``spec_ids`` is non-empty the specs are restricted to those ids and every
  id must be found. Existing rollouts are kept unless ``overwrite`` is truthy.
  The file is rewritten only if ``update_rollout_dict`` reported a change.
  """
  candidates = [s for s in envs.registry.all() if s._entry_point is not None]
  if spec_ids:
    candidates = [s for s in candidates if s.id in spec_ids]
    assert len(candidates) == len(spec_ids), "Some specs not found"

  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  changed = False
  for s in candidates:
    if overwrite or s.id not in rollout_dict:
      # Call the updater first so it always runs, then fold in the flag.
      changed = update_rollout_dict(s, rollout_dict) or changed
    else:
      logger.debug("Rollout already exists for {}. Skipping.".format(s.id))

  if not changed:
    logger.info("No modifications needed.")
    return

  logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
  with open(ROLLOUT_FILE, "w") as outfile:
    json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
项目:gym    作者:openai    | 项目源码 | 文件源码
def should_skip_env_spec_for_tests(spec):
    """Return True when tests for ``spec`` should be skipped.

    We skip tests for envs that require dependencies or are otherwise
    troublesome to run frequently. The decision is based on the spec's
    ``_entry_point`` string and, for Atari entry points, on ``spec.id``
    (only ids starting with "Pong" or "Seaquest" are kept).
    """
    ep = spec._entry_point
    # Skip mujoco tests for pull request CI
    skip_mujoco = not (os.environ.get('MUJOCO_KEY_BUNDLE') or os.path.exists(os.path.expanduser('~/.mujoco')))
    if skip_mujoco and ep.startswith('gym.envs.mujoco:'):
        return True

    # str.startswith accepts a tuple, so each prefix is listed exactly once
    # (the original chain repeated the box2d check).
    skipped_prefixes = (
        'gym.envs.box2d:',
        'gym.envs.parameter_tuning:',
        'gym.envs.safety:Semisuper',
    )
    untested_atari = (
        ep.startswith("gym.envs.atari")
        and not spec.id.startswith(("Pong", "Seaquest"))
    )
    if ('GoEnv' in ep or
            'HexEnv' in ep or
            ep.startswith(skipped_prefixes) or
            untested_atari):
        logger.warning("Skipping tests for env {}".format(ep))
        return True
    return False
项目:AI-Fight-the-Landlord    作者:YoungGer    | 项目源码 | 文件源码
def add_new_rollouts(spec_ids, overwrite):
  """Add rollouts for registered environment specs to ROLLOUT_FILE.

  Only specs whose ``_entry_point`` is not None are considered. When
  ``spec_ids`` is non-empty the specs are restricted to those ids and every
  id must be found. Existing rollouts are kept unless ``overwrite`` is truthy.
  The file is rewritten only if ``update_rollout_dict`` reported a change.
  """
  candidates = [s for s in envs.registry.all() if s._entry_point is not None]
  if spec_ids:
    candidates = [s for s in candidates if s.id in spec_ids]
    assert len(candidates) == len(spec_ids), "Some specs not found"

  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  changed = False
  for s in candidates:
    if overwrite or s.id not in rollout_dict:
      # Call the updater first so it always runs, then fold in the flag.
      changed = update_rollout_dict(s, rollout_dict) or changed
    else:
      logger.debug("Rollout already exists for {}. Skipping.".format(s.id))

  if not changed:
    logger.info("No modifications needed.")
    return

  logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
  with open(ROLLOUT_FILE, "w") as outfile:
    json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
项目:gym-adv    作者:lerrel    | 项目源码 | 文件源码
def add_new_rollouts(spec_ids, overwrite):
  """Add rollouts for registered environment specs to ROLLOUT_FILE.

  Only specs whose ``_entry_point`` is not None are considered. When
  ``spec_ids`` is non-empty the specs are restricted to those ids and every
  id must be found. Existing rollouts are kept unless ``overwrite`` is truthy.
  The file is rewritten only if ``update_rollout_dict`` reported a change.
  """
  candidates = [s for s in envs.registry.all() if s._entry_point is not None]
  if spec_ids:
    candidates = [s for s in candidates if s.id in spec_ids]
    assert len(candidates) == len(spec_ids), "Some specs not found"

  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  changed = False
  for s in candidates:
    if overwrite or s.id not in rollout_dict:
      # Call the updater first so it always runs, then fold in the flag.
      changed = update_rollout_dict(s, rollout_dict) or changed
    else:
      logger.debug("Rollout already exists for {}. Skipping.".format(s.id))

  if not changed:
    logger.info("No modifications needed.")
    return

  logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
  with open(ROLLOUT_FILE, "w") as outfile:
    json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
项目:third_person_im    作者:bstadie    | 项目源码 | 文件源码
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True):
        """Create the Gym environment named ``env_name`` and optionally monitor it.

        Args:
            env_name: Id passed to ``gym.envs.make``.
            record_video: When falsy, a ``NoVideoSchedule`` is used.
            video_schedule: Schedule passed to the monitor; defaults to
                ``CappedCubicVideoSchedule()`` when video is recorded and no
                schedule is given.
            log_dir: Monitor output directory. When None it falls back to
                ``<snapshot_dir>/gym_log`` if the logger has a snapshot dir.
            record_log: When False, monitoring is disabled.

        Raises:
            AssertionError: if ``record_video`` is truthy while ``record_log``
                is falsy.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, after log_dir has been resolved and before
        # any further locals exist. Presumably used to record the constructor
        # arguments for serialization -- confirm against Serializable.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Raise the monitor logger's threshold to WARNING.
        monitor.logger.setLevel(logging.WARNING)

        # Recording video without recording a log is not allowed.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            self.env.monitor.start(log_dir, video_schedule, force=True)  # add 'force=True' if want overwrite dirs
            self.monitoring = True

        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        # Horizon is read from the spec's timestep_limit attribute.
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
项目:rllabplusplus    作者:shaneshixiang    | 项目源码 | 文件源码
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        """Create the Gym environment named ``env_name`` and optionally monitor it.

        Args:
            env_name: Id passed to ``gym.envs.make``.
            record_video: When falsy, a ``NoVideoSchedule`` is used.
            video_schedule: Passed to ``gym.wrappers.Monitor`` as
                ``video_callable``; defaults to ``CappedCubicVideoSchedule()``
                when video is recorded and no schedule is given.
            log_dir: Monitor output directory. When None it falls back to
                ``<snapshot_dir>/gym_log`` if the logger has a snapshot dir.
            record_log: When False, monitoring is disabled.
            force_reset: Stored on ``self._force_reset``; not otherwise used
                in this method.

        Raises:
            AssertionError: if ``record_video`` is truthy while ``record_log``
                is falsy.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, after log_dir has been resolved and before
        # any further locals exist. Presumably used to record the constructor
        # arguments for serialization -- confirm against Serializable.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Recording video without recording a log is not allowed.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # self.env becomes the Monitor wrapper; the local `env` still
            # refers to the environment returned by gym.envs.make.
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        # Spaces and horizon are read from the local `env`, not from self.env.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # Raises KeyError if the spec's tags lack this key.
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
项目:gym    作者:openai    | 项目源码 | 文件源码
def test_env_semantics(spec):
  """Compare a freshly generated rollout hash for ``spec`` with the stored one.

  Returns early when ROLLOUT_FILE has no entry for ``spec.id`` (warning unless
  the spec is nondeterministic). Otherwise every mismatching field is logged
  and a ValueError carrying the list of messages is raised.
  """
  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  if spec.id not in rollout_dict:
    if not spec.nondeterministic:
      logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
    return

  logger.info("Testing rollout for {} environment...".format(spec.id))

  observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

  # (stored key, message template, freshly computed value), in report order.
  comparisons = (
    ('observations', 'Observations not equal for {} -- expected {} but got {}', observations_now),
    ('actions', 'Actions not equal for {} -- expected {} but got {}', actions_now),
    ('rewards', 'Rewards not equal for {} -- expected {} but got {}', rewards_now),
    ('dones', 'Dones not equal for {} -- expected {} but got {}', dones_now),
  )
  errors = []
  for key, template, current in comparisons:
    if rollout_dict[spec.id][key] != current:
      errors.append(template.format(spec.id, rollout_dict[spec.id][key], current))

  if errors:
    for message in errors:
      logger.warn(message)
    raise ValueError(errors)
项目:AI-Fight-the-Landlord    作者:YoungGer    | 项目源码 | 文件源码
def __init__(self):
        """Initialise ``groups``, ``envs`` and ``benchmarks`` as empty OrderedDicts."""
        for attr in ('groups', 'envs', 'benchmarks'):
            setattr(self, attr, collections.OrderedDict())
项目:AI-Fight-the-Landlord    作者:YoungGer    | 项目源码 | 文件源码
def env(self, id):
        """Return the entry stored in ``self.envs`` under ``id``."""
        entry = self.envs[id]
        return entry
项目:AI-Fight-the-Landlord    作者:YoungGer    | 项目源码 | 文件源码
def add_group(self, id, name, description, universe=False):
        """Store a new group record in ``self.groups`` under ``id``.

        The record starts with an empty ``'envs'`` list and carries the
        ``universe`` flag.
        """
        record = dict(
            id=id,
            name=name,
            description=description,
            envs=[],
            universe=universe,
        )
        self.groups[id] = record
项目:AI-Fight-the-Landlord    作者:YoungGer    | 项目源码 | 文件源码
def add_task(self, id, group, summary=None, description=None, background=None, deprecated=False, experimental=False, contributor=None):
        """Store a task record in ``self.envs`` and link it to its group.

        The id is appended to ``self.groups[group]['envs']`` unless the task
        is marked ``deprecated``.
        """
        record = dict(
            group=group,
            id=id,
            summary=summary,
            description=description,
            background=background,
            deprecated=deprecated,
            experimental=experimental,
            contributor=contributor,
        )
        self.envs[id] = record
        if deprecated:
            return
        self.groups[group]['envs'].append(id)
项目:AI-Fight-the-Landlord    作者:YoungGer    | 项目源码 | 文件源码
def test_env_semantics(spec):
  """Compare a freshly generated rollout hash for ``spec`` with the stored one.

  Returns early when ROLLOUT_FILE has no entry for ``spec.id`` (warning unless
  the spec is nondeterministic). Otherwise every mismatching field is logged
  and a ValueError carrying the list of messages is raised.
  """
  with open(ROLLOUT_FILE) as data_file:
    rollout_dict = json.load(data_file)

  if spec.id not in rollout_dict:
    if not spec.nondeterministic:
      logger.warn("Rollout does not exist for {}, run generate_json.py to generate rollouts for new envs".format(spec.id))
    return

  logger.info("Testing rollout for {} environment...".format(spec.id))

  observations_now, actions_now, rewards_now, dones_now = generate_rollout_hash(spec)

  # (stored key, message template, freshly computed value), in report order.
  comparisons = (
    ('observations', 'Observations not equal for {} -- expected {} but got {}', observations_now),
    ('actions', 'Actions not equal for {} -- expected {} but got {}', actions_now),
    ('rewards', 'Rewards not equal for {} -- expected {} but got {}', rewards_now),
    ('dones', 'Dones not equal for {} -- expected {} but got {}', dones_now),
  )
  errors = []
  for key, template, current in comparisons:
    if rollout_dict[spec.id][key] != current:
      errors.append(template.format(spec.id, rollout_dict[spec.id][key], current))

  if errors:
    for message in errors:
      logger.warn(message)
    raise ValueError(errors)
项目:gail-driver    作者:sisl    | 项目源码 | 文件源码
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        """Create the Gym environment named ``env_name`` and optionally monitor it.

        Args:
            env_name: Id passed to ``gym.envs.make``.
            record_video: When falsy, a ``NoVideoSchedule`` is used.
            video_schedule: Passed to ``gym.wrappers.Monitor`` as
                ``video_callable``; defaults to ``CappedCubicVideoSchedule()``
                when video is recorded and no schedule is given.
            log_dir: Monitor output directory. When None it falls back to
                ``<snapshot_dir>/gym_log`` if the logger has a snapshot dir.
            record_log: When False, monitoring is disabled.
            force_reset: Stored on ``self._force_reset``; not otherwise used
                in this method.

        Raises:
            AssertionError: if ``record_video`` is truthy while ``record_log``
                is falsy.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log(
                    "Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, after log_dir has been resolved and before
        # any further locals exist. Presumably used to record the constructor
        # arguments for serialization -- confirm against Serializable.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Recording video without recording a log is not allowed.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # self.env becomes the Monitor wrapper; the local `env` still
            # refers to the environment returned by gym.envs.make.
            self.env = gym.wrappers.Monitor(
                self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        # Spaces and horizon are read from the local `env`, not from self.env.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # Raises KeyError if the spec's tags lack this key.
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
项目:DeepQNetwork    作者:bakanaouji    | 项目源码 | 文件源码
def wrap_dqn(env, history_len=4, action_repeat=4, no_op_max=30):
    """
    Apply the stack of environment wrappers used for DQN to ``env``.

    Parameters
    ----------
    env: gym.envs
        Environment to wrap.
    history_len: int
        Forwarded to ``FrameStack``.
    action_repeat: int
        Forwarded to ``MaxAndSkipEnv``.
    no_op_max: int
        Forwarded to ``NoOpResetEnv``.

    Returns
    ----------
    env
        The environment after all wrappers have been applied; the outermost
        wrapper is ``ScaledFloatFrame``.
    """
    wrapped = NoOpResetEnv(EpisodicLifeEnv(env), no_op_max)
    wrapped = MaxAndSkipEnv(wrapped, action_repeat)
    # FireResetEnv is only added when the action set includes 'FIRE'.
    action_names = wrapped.unwrapped.get_action_meanings()
    if 'FIRE' in action_names:
        wrapped = FireResetEnv(wrapped)
    wrapped = FrameStack(ProcessFrame84(wrapped), history_len)
    return ScaledFloatFrame(ClippedRewardsWrapper(wrapped))
项目:rllab    作者:rll    | 项目源码 | 文件源码
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        """Create the Gym environment named ``env_name`` and optionally monitor it.

        Args:
            env_name: Id passed to ``gym.envs.make``.
            record_video: When falsy, a ``NoVideoSchedule`` is used.
            video_schedule: Passed to ``gym.wrappers.Monitor`` as
                ``video_callable``; defaults to ``CappedCubicVideoSchedule()``
                when video is recorded and no schedule is given.
            log_dir: Monitor output directory. When None it falls back to
                ``<snapshot_dir>/gym_log`` if the logger has a snapshot dir.
            record_log: When False, monitoring is disabled.
            force_reset: Stored on ``self._force_reset``; not otherwise used
                in this method.

        Raises:
            AssertionError: if ``record_video`` is truthy while ``record_log``
                is falsy.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, after log_dir has been resolved and before
        # any further locals exist. Presumably used to record the constructor
        # arguments for serialization -- confirm against Serializable.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Recording video without recording a log is not allowed.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # self.env becomes the Monitor wrapper; the local `env` still
            # refers to the environment returned by gym.envs.make.
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        # Spaces and horizon are read from the local `env`, not from self.env.
        self._observation_space = convert_gym_space(env.observation_space)
        logger.log("observation space: {}".format(self._observation_space))
        self._action_space = convert_gym_space(env.action_space)
        logger.log("action space: {}".format(self._action_space))
        # Raises KeyError if the spec's tags lack this key.
        self._horizon = env.spec.tags['wrapper_config.TimeLimit.max_episode_steps']
        self._log_dir = log_dir
        self._force_reset = force_reset
项目:gym-adv    作者:lerrel    | 项目源码 | 文件源码
def __init__(self):
        """Initialise ``groups``, ``envs`` and ``benchmarks`` as empty OrderedDicts."""
        for attr in ('groups', 'envs', 'benchmarks'):
            setattr(self, attr, collections.OrderedDict())
项目:gym-adv    作者:lerrel    | 项目源码 | 文件源码
def env(self, id):
        """Return the entry stored in ``self.envs`` under ``id``."""
        entry = self.envs[id]
        return entry
项目:gym-adv    作者:lerrel    | 项目源码 | 文件源码
def add_group(self, id, name, description):
        """Store a new group record in ``self.groups`` under ``id``.

        The record starts with an empty ``'envs'`` list.
        """
        record = dict(
            id=id,
            name=name,
            description=description,
            envs=[],
        )
        self.groups[id] = record
项目:gym-adv    作者:lerrel    | 项目源码 | 文件源码
def add_task(self, id, group, summary=None, description=None, background=None, deprecated=False, experimental=False, contributor=None):
        """Store a task record in ``self.envs`` and link it to its group.

        The id is appended to ``self.groups[group]['envs']`` unless the task
        is marked ``deprecated``.
        """
        record = dict(
            group=group,
            id=id,
            summary=summary,
            description=description,
            background=background,
            deprecated=deprecated,
            experimental=experimental,
            contributor=contributor,
        )
        self.envs[id] = record
        if deprecated:
            return
        self.groups[group]['envs'].append(id)
项目:maml_rl    作者:cbfinn    | 项目源码 | 文件源码
def __init__(self, env_name, record_video=True, video_schedule=None, log_dir=None, record_log=True,
                 force_reset=False):
        """Create the Gym environment named ``env_name`` and optionally monitor it.

        Args:
            env_name: Id passed to ``gym.envs.make``.
            record_video: When falsy, a ``NoVideoSchedule`` is used.
            video_schedule: Passed to ``gym.wrappers.Monitor`` as
                ``video_callable``; defaults to ``CappedCubicVideoSchedule()``
                when video is recorded and no schedule is given.
            log_dir: Monitor output directory. When None it falls back to
                ``<snapshot_dir>/gym_log`` if the logger has a snapshot dir.
            record_log: When False, monitoring is disabled.
            force_reset: Stored on ``self._force_reset``; not otherwise used
                in this method.

        Raises:
            AssertionError: if ``record_video`` is truthy while ``record_log``
                is falsy.
        """
        if log_dir is None:
            if logger.get_snapshot_dir() is None:
                logger.log("Warning: skipping Gym environment monitoring since snapshot_dir not configured.")
            else:
                log_dir = os.path.join(logger.get_snapshot_dir(), "gym_log")
        # locals() is captured here, after log_dir has been resolved and before
        # any further locals exist. Presumably used to record the constructor
        # arguments for serialization -- confirm against Serializable.
        Serializable.quick_init(self, locals())

        env = gym.envs.make(env_name)
        self.env = env
        self.env_id = env.spec.id

        # Raise the monitor_manager logger's threshold to WARNING.
        monitor_manager.logger.setLevel(logging.WARNING)

        # Recording video without recording a log is not allowed.
        assert not (not record_log and record_video)

        if log_dir is None or record_log is False:
            self.monitoring = False
        else:
            if not record_video:
                video_schedule = NoVideoSchedule()
            else:
                if video_schedule is None:
                    video_schedule = CappedCubicVideoSchedule()
            # self.env becomes the Monitor wrapper; the local `env` still
            # refers to the environment returned by gym.envs.make.
            self.env = gym.wrappers.Monitor(self.env, log_dir, video_callable=video_schedule, force=True)
            self.monitoring = True

        # Spaces and horizon are read from the local `env`, not from self.env.
        self._observation_space = convert_gym_space(env.observation_space)
        self._action_space = convert_gym_space(env.action_space)
        # Horizon is read from the spec's timestep_limit attribute.
        self._horizon = env.spec.timestep_limit
        self._log_dir = log_dir
        self._force_reset = force_reset