DQN and Variants

Example

 from rlzoo.common.env_wrappers import build_env
 from rlzoo.common.utils import call_default_params
 from rlzoo.algorithms import DQN

 AlgName = 'DQN'
 EnvName = 'PongNoFrameskip-v4'
 EnvType = 'atari'

 # EnvName = 'CartPole-v1'
 # EnvType = 'classic_control'  # the env name must match its env type

 # build the wrapped environment and fetch default hyper-parameters for DQN
 env = build_env(EnvName, EnvType)
 alg_params, learn_params = call_default_params(env, EnvType, AlgName)

 # instantiate the algorithm from its name, then train and evaluate it
 alg = eval(AlgName+'(**alg_params)')
 alg.learn(env=env, mode='train', **learn_params)
 alg.learn(env=env, mode='test', render=True, **learn_params)

Deep Q-Networks

class rlzoo.algorithms.dqn.dqn.DQN(net_list, optimizers_list, double_q, dueling, buffer_size, prioritized_replay, prioritized_alpha, prioritized_beta0)[source]

Papers:

Mnih V, Kavukcuoglu K, Silver D, et al. Human-level control through deep reinforcement learning[J]. Nature, 2015, 518(7540): 529.

Hessel M, Modayil J, Van Hasselt H, et al. Rainbow: Combining Improvements in Deep Reinforcement Learning[J]. 2017.
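
The constructor arguments above (networks, optimizers, Double/Dueling switches, replay-buffer settings) are normally supplied by call_default_params, as in the example at the top. A minimal sketch, assuming alg_params is a plain dict so its Rainbow-style flags can be overridden before construction:

 from rlzoo.common.env_wrappers import build_env
 from rlzoo.common.utils import call_default_params
 from rlzoo.algorithms import DQN

 env = build_env('CartPole-v1', 'classic_control')
 # call_default_params returns the constructor arguments (net_list, optimizers_list,
 # double_q, dueling, buffer_size, prioritized replay settings, ...) plus the
 # keyword arguments for learn()
 alg_params, learn_params = call_default_params(env, 'classic_control', 'DQN')
 # assumption: alg_params is a dict, so individual switches can be changed here,
 # e.g. alg_params['double_q'] = True or alg_params['dueling'] = True
 alg = DQN(**alg_params)   # equivalent to eval('DQN(**alg_params)') in the example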

get_action(obv, eps=0.2)[source]
get_action_greedy(obv)[source]
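
get_action(obv, eps) selects an action from the Q-values of the observation with exploration controlled by eps, while get_action_greedy(obv) always takes the argmax. In DQN this is the usual epsilon-greedy rule; a minimal sketch of the rule itself, independent of RLzoo internals (the function name and the example Q-values are illustrative):

 import numpy as np

 def epsilon_greedy(q_values, eps=0.2):
     # with probability eps take a uniformly random action (exploration),
     # otherwise the action with the highest Q-value (what get_action_greedy does)
     if np.random.rand() < eps:
         return np.random.randint(len(q_values))
     return int(np.argmax(q_values))

 # q_values would come from a forward pass of the Q network on the observation
 print(epsilon_greedy(np.array([0.1, 0.5, -0.2]), eps=0.2))
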
learn(env, mode='train', render=False, train_episodes=1000, test_episodes=10, max_steps=200, save_interval=1000, gamma=0.99, exploration_rate=0.2, exploration_final_eps=0.01, target_network_update_freq=50, batch_size=32, train_freq=4, learning_starts=200, plot_func=None)[source]
Parameters:
  • env – learning environment
  • mode – train or test
  • render – render each step
  • train_episodes – total number of episodes for training
  • test_episodes – total number of episodes for testing
  • max_steps – maximum number of steps for one episode
  • save_interval – time steps for saving
  • gamma – reward decay factor
  • exploration_rate (float) – fraction of the entire training period over which the exploration rate is annealed
  • exploration_final_eps (float) – final value of the random-action probability
  • target_network_update_freq (int) – update the target network every target_network_update_freq steps
  • batch_size (int) – size of a batch sampled from the replay buffer for training
  • train_freq (int) – update the model every train_freq steps
  • learning_starts (int) – how many steps of the model to collect transitions for before learning starts
  • plot_func – additional function for interactive module
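
A hedged usage sketch, with env and alg built as in the example at the top and the hyper-parameters written out explicitly at the default values from the signature above (in practice call_default_params already provides tuned values per environment type):

 alg.learn(
     env=env,
     mode='train',
     render=False,
     train_episodes=1000,
     max_steps=200,
     save_interval=1000,
     gamma=0.99,
     exploration_rate=0.2,
     exploration_final_eps=0.01,
     target_network_update_freq=50,
     batch_size=32,
     train_freq=4,
     learning_starts=200,
 )
 # evaluate the trained greedy policy for a few episodes with rendering enabled
 alg.learn(env=env, mode='test', test_episodes=10, render=True)
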
load_ckpt(env_name)[source]

Load trained weights.

Returns: None

save_ckpt(env_name)[source]

Save trained weights.

Returns: None
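
During training, learn() already saves weights every save_interval steps; the two methods above expose the same machinery directly. A small sketch (the env_name value is illustrative and should match the environment that was trained):

 alg.save_ckpt(env_name='CartPole-v1')   # persist the current Q-network weights
 alg.load_ckpt(env_name='CartPole-v1')   # restore them later, e.g. before running mode='test'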

store_transition(s, a, r, s_, d)[source]
sync()[source]

Copy the Q network weights to the target Q network.

update(batch_size, gamma)[source]
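
update(batch_size, gamma) samples a batch from the replay buffer and regresses the Q network toward TD targets, while sync() copies the online weights into the target network at the frequency set by target_network_update_freq. A rough sketch of the target computation only, with double_q switching between the vanilla and Double DQN estimators (names and shapes here are illustrative, not RLzoo's internals):

 import numpy as np

 def td_targets(rewards, dones, q_next_online, q_next_target, gamma=0.99, double_q=True):
     # q_next_online / q_next_target: Q-values of the next states from the online
     # and target networks, shape (batch, n_actions)
     if double_q:
         # Double DQN: the online network picks the action, the target network evaluates it
         a_star = np.argmax(q_next_online, axis=1)
         q_next = q_next_target[np.arange(len(a_star)), a_star]
     else:
         # vanilla DQN: the target network both picks and evaluates the action
         q_next = np.max(q_next_target, axis=1)
     # no bootstrapping through terminal transitions
     return rewards + gamma * (1.0 - dones) * q_next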

Default Hyper-parameters

rlzoo.algorithms.dqn.default.atari(env, default_seed=False, **kwargs)[source]
rlzoo.algorithms.dqn.default.classic_control(env, default_seed=False, **kwargs)[source]
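
These per-environment-type functions supply the defaults that call_default_params forwards for DQN. A hedged sketch of calling one directly, assuming it returns the same (alg_params, learn_params) pair used in the example at the top:

 from rlzoo.common.env_wrappers import build_env
 from rlzoo.algorithms.dqn import default

 env = build_env('PongNoFrameskip-v4', 'atari')
 # assumption: default.atari returns the (alg_params, learn_params) pair
 alg_params, learn_params = default.atari(env, default_seed=True)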