TRPO

Example

 from rlzoo.common.env_wrappers import build_env
 from rlzoo.common.utils import call_default_params
 from rlzoo.algorithms import TRPO  # must match AlgName below, since eval() instantiates it by name

 AlgName = 'TRPO'
 EnvName = 'PongNoFrameskip-v4'
 EnvType = 'atari'

 # EnvName = 'CartPole-v0'
 # EnvType = 'classic_control'

 # EnvName = 'BipedalWalker-v2'
 # EnvType = 'box2d'

 # EnvName = 'Ant-v2'
 # EnvType = 'mujoco'

 # EnvName = 'FetchPush-v1'
 # EnvType = 'robotics'

 # EnvName = 'FishSwim-v0'
 # EnvType = 'dm_control'

 # EnvName = 'ReachTarget'
 # EnvType = 'rlbench'

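 # build the environment and load the default TRPO hyper-parameters for its type
 env = build_env(EnvName, EnvType)
 alg_params, learn_params = call_default_params(env, EnvType, AlgName)
 alg = eval(AlgName+'(**alg_params)')  # instantiates TRPO(**alg_params)
 # train first, then evaluate the trained policy with rendering enabled
 alg.learn(env=env, mode='train', render=False, **learn_params)
 alg.learn(env=env, mode='test', render=True, **learn_params)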

Trust Region Policy Optimization

class rlzoo.algorithms.trpo.trpo.TRPO(net_list, optimizers_list, damping_coeff=0.1, cg_iters=10, delta=0.01)

Trust Region Policy Optimization (TRPO) algorithm class
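For orientation, the standard TRPO update is sketched below. Mapping delta to the KL-divergence bound, damping_coeff to the damping term \lambda, cg_iters to the number of conjugate gradient iterations used to approximate the solve, and backtrack_coeff to \alpha is an assumption based on the parameter names and the learn() documentation, not on the RLzoo source. Here g is the gradient of the surrogate objective and H the Hessian of the mean KL divergence.

```latex
\begin{aligned}
&\max_{\theta}\; \mathbb{E}_{s,a \sim \pi_{\theta_{\mathrm{old}}}}\!\left[\frac{\pi_\theta(a \mid s)}{\pi_{\theta_{\mathrm{old}}}(a \mid s)}\, A(s,a)\right]
\quad \text{s.t.} \quad
\mathbb{E}_{s}\!\left[D_{\mathrm{KL}}\big(\pi_{\theta_{\mathrm{old}}}(\cdot \mid s)\,\|\,\pi_\theta(\cdot \mid s)\big)\right] \le \delta \\[4pt]
&x \approx (H + \lambda I)^{-1} g, \qquad
\theta_{\mathrm{new}} = \theta_{\mathrm{old}} + \alpha^{j} \sqrt{\frac{2\delta}{x^{\top} (H + \lambda I)\, x}}\; x
\end{aligned}
```

Here j = 0, 1, ..., backtrack_iters - 1 is the first index whose step satisfies the KL constraint and improves the surrogate objective.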

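a_train(s, a, adv, oldpi_prob, backtrack_iters, backtrack_coeff)

Update actor network

static assign_params_from_flat(x, params)

Assign values from a flat input tensor to a list of parameters

Parameters:
  • x – flat tensor holding the new parameter values
  • params – list of parameters to assign
Returns:

group of assignment operations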

c_train(tfdc_r, s)

Update critic network

Parameters:
  • tfdc_r – discounted cumulative reward
  • s – state
Returns:

None

cal_adv(tfs, tfdc_r)

Calculate advantage

Parameters:
  • tfs – state
  • tfdc_r – discounted cumulative reward
Returns:

advantage
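A minimal NumPy sketch of how the discounted cumulative reward (tfdc_r), the advantage and the critic's squared-error loss relate. It assumes cal_adv returns tfdc_r minus the critic's value V(s), that c_train minimizes the mean squared error between the two, and that returns are bootstrapped from the value of the last state; the helper names discounted_returns, advantage_sketch and critic_loss_sketch are hypothetical, not RLzoo functions.

```python
import numpy as np

def discounted_returns(rewards, last_value, gamma=0.9):
    """Discounted cumulative reward for each step, bootstrapped from last_value."""
    returns = np.zeros(len(rewards))
    running = last_value
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def advantage_sketch(values, returns):
    """Advantage estimate: discounted return minus the critic's value V(s)."""
    return returns - values

def critic_loss_sketch(values, returns):
    """Mean squared error between the discounted returns and V(s)."""
    return float(np.mean((returns - values) ** 2))

rewards = [1.0, 1.0, 1.0]
values = np.array([1.0, 1.0, 1.0])   # hypothetical critic outputs V(s)
returns = discounted_returns(rewards, last_value=0.0, gamma=0.5)
adv = advantage_sketch(values, returns)
loss = critic_loss_sketch(values, returns)
```

With gamma=0.5 the returns are 1.75, 1.5 and 1.0, so only the earlier steps get a positive advantage over the (hypothetical) constant value estimate.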

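cg(Ax, b)

Conjugate gradient algorithm (see https://en.wikipedia.org/wiki/Conjugate_gradient_method): approximately solves A x = b, where Ax computes the product of A with a vector and b is the right-hand side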
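A standalone sketch of the conjugate gradient method, assuming (as is usual for TRPO) that Ax is a callable returning a matrix-vector product rather than an explicit matrix. The function name conjugate_gradient, the tol argument and the damped product H v + damping_coeff * v are illustrative assumptions, not RLzoo's implementation of cg or hessian_vector_product.

```python
import numpy as np

def conjugate_gradient(Ax, b, cg_iters=10, tol=1e-10):
    """Approximately solve A x = b for a symmetric positive-definite A, using only
    a function Ax(v) that returns the matrix-vector product A @ v."""
    b = np.asarray(b, dtype=np.float64)
    x = np.zeros_like(b)
    r = b.copy()        # residual b - A x (x starts at zero)
    p = r.copy()        # current search direction
    r_dot = r @ r
    for _ in range(cg_iters):
        if r_dot < tol:  # residual already small enough: stop early
            break
        Ap = Ax(p)
        alpha = r_dot / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        new_r_dot = r @ r
        p = r + (new_r_dot / r_dot) * p   # next direction, A-conjugate to the previous ones
        r_dot = new_r_dot
    return x

# Hypothetical curvature matrix H with damping, mimicking a damped Hessian-vector product.
H = np.array([[4.0, 1.0], [1.0, 3.0]])
damping_coeff = 0.1
hvp = lambda v: H @ v + damping_coeff * v
g = np.array([1.0, 2.0])
step_dir = conjugate_gradient(hvp, g)   # approximately (H + 0.1 I)^{-1} g
```

In TRPO, Ax would be the Hessian-vector product of the KL divergence and b the policy gradient, so the Hessian never has to be formed explicitly.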

eval(bs, ba, badv, oldpi_prob)
static flat_concat(xs)

Flatten and concatenate the input tensors

Parameters: xs – a list of tensors
Returns: flat tensor
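A NumPy sketch of the flatten/unflatten round trip that flat_concat and assign_params_from_flat describe. RLzoo works on TensorFlow tensors and variables; treating NumPy arrays as stand-ins is an assumption, and unflatten_like is a hypothetical helper that returns the reshaped pieces instead of assigning them.

```python
import numpy as np

def flat_concat(xs):
    """Flatten each array and concatenate them into one 1-D vector."""
    return np.concatenate([np.reshape(x, -1) for x in xs])

def unflatten_like(flat, params):
    """Inverse of flat_concat: split a flat vector back into arrays
    with the same shapes as params."""
    sizes = [int(np.prod(p.shape)) for p in params]
    pieces = np.split(flat, np.cumsum(sizes)[:-1])
    return [piece.reshape(p.shape) for piece, p in zip(pieces, params)]

params = [np.ones((2, 3)), np.zeros(4), np.array(5.0)]   # hypothetical weights
flat = flat_concat(params)        # 6 + 4 + 1 = 11 values
restored = unflatten_like(flat, params)
```

Flattening lets the conjugate gradient step and the line search treat all actor weights as a single vector.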
get_action(s)

Choose action

Parameters: s – state
Returns: clipped action
get_action_greedy(s)

Choose action greedily

Parameters: s – state
Returns: clipped action
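The two methods differ in how the action is chosen. A plausible reading, assumed here rather than taken from the RLzoo source, is that get_action samples from the policy distribution while get_action_greedy takes its most likely action, both clipped to the action bounds. The sketch below assumes a continuous Gaussian policy with hypothetical bounds; for a discrete space such as Pong, the greedy choice would be an argmax instead.

```python
import numpy as np

rng = np.random.default_rng(0)
ACTION_LOW, ACTION_HIGH = -2.0, 2.0   # hypothetical action bounds of the environment

def sample_action(mean, std):
    """Stochastic choice: draw from N(mean, std^2), then clip into the action bounds."""
    a = rng.normal(mean, std)
    return np.clip(a, ACTION_LOW, ACTION_HIGH)

def greedy_action(mean):
    """Greedy choice: take the distribution mean directly, then clip into the bounds."""
    return np.clip(mean, ACTION_LOW, ACTION_HIGH)

mean = np.array([0.5, 3.0])          # hypothetical policy output for one state
print(greedy_action(mean))           # mean clipped into [-2, 2]
print(sample_action(mean, std=0.3))  # random, but always inside [-2, 2]
```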
get_pi_params()

Get actor trainable parameters

Returns: flat actor trainable parameters
get_v(s)

Compute value

Parameters: s – state
Returns: value
gradient(inputs)

Compute pi (policy) gradients

Parameters: inputs – a list of x_ph, a_ph, adv_ph, ret_ph, logp_old_ph and other inputs
Returns: gradient
hessian_vector_product(s, a, adv, oldpi_prob, v_ph)
learn(env, train_episodes=200, test_episodes=100, max_steps=200, save_interval=10, gamma=0.9, mode='train', render=False, batch_size=32, backtrack_iters=10, backtrack_coeff=0.8, train_critic_iters=80, plot_func=None)

Train or test the agent, depending on mode

Parameters:
  • env – learning environment
  • train_episodes – total number of episodes for training
  • test_episodes – total number of episodes for testing
  • max_steps – maximum number of steps for one episode
  • save_interval – time steps for saving
  • gamma – reward discount factor
  • mode – train or test
  • render – render each step
  • batch_size – update batch size
  • backtrack_iters – Maximum number of steps allowed in the backtracking line search
  • backtrack_coeff – factor by which the step is shrunk at each iteration of the backtracking line search
  • train_critic_iters – critic update iteration steps
Returns:

None
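A minimal sketch of the backtracking line search controlled by backtrack_iters and backtrack_coeff, using the standard TRPO acceptance test: a candidate must keep the KL divergence within delta and improve the surrogate objective. The function backtracking_line_search, the callables surrogate and kl, and the toy problem are hypothetical; in TRPO, full_step would be the scaled conjugate gradient direction.

```python
import numpy as np

def backtracking_line_search(params, full_step, surrogate, kl, delta=0.01,
                             backtrack_iters=10, backtrack_coeff=0.8):
    """Try params + backtrack_coeff**j * full_step for j = 0, 1, ... and accept the
    first candidate that keeps the KL divergence within delta and improves the
    surrogate objective. Returns (new_params, j), or (params, None) if none passes."""
    old_obj = surrogate(params)
    for j in range(backtrack_iters):
        candidate = params + (backtrack_coeff ** j) * full_step
        if kl(candidate) <= delta and surrogate(candidate) > old_obj:
            return candidate, j
    return params, None  # no acceptable step: keep the old parameters

# Toy problem (hypothetical): maximize -(theta - 1)^2 starting from theta = 0,
# with a quadratic stand-in for the KL divergence from the old parameters.
theta_old = np.array([0.0])
surrogate = lambda th: -float(np.sum((th - 1.0) ** 2))
kl = lambda th: 0.5 * float(np.sum((th - theta_old) ** 2))
theta_new, j = backtracking_line_search(theta_old, np.array([1.0]), surrogate, kl)
```

Here the full step violates the KL bound, so it is shrunk by 0.8 repeatedly until 0.8**9 ≈ 0.134 fits; with too few backtrack_iters the update is rejected and the old parameters are kept.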

load_ckpt(env_name)

Load trained weights

Returns: None
pi_loss(inputs)

Calculate pi (policy) loss

Parameters: inputs – a list of x_ph, a_ph, adv_ph, ret_ph, logp_old_ph and other inputs
Returns: pi loss
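A sketch of the usual TRPO policy loss: the negative mean of the probability ratio times the advantage. That pi_loss computes exactly this, with the ratio formed from the current log-probabilities and logp_old_ph, is an assumption suggested by its inputs; surrogate_pi_loss and the sample numbers are hypothetical.

```python
import numpy as np

def surrogate_pi_loss(logp, logp_old, adv):
    """Negative mean of ratio * advantage, where ratio = pi(a|s) / pi_old(a|s).
    Minimizing this loss maximizes the TRPO surrogate objective."""
    ratio = np.exp(logp - logp_old)   # probability ratio computed from log-probabilities
    return -float(np.mean(ratio * adv))

logp_old = np.log(np.array([0.25, 0.5]))   # hypothetical old-policy log-probs
logp = np.log(np.array([0.5, 0.5]))        # hypothetical new-policy log-probs
adv = np.array([1.0, 1.0])
loss = surrogate_pi_loss(logp, logp_old, adv)
```

Doubling the probability of the first action gives ratios of 2 and 1, so the loss is -1.5; when the policy is unchanged the ratio is 1 and the loss reduces to the negative mean advantage.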
save_ckpt(env_name)

Save trained weights

Returns: None
set_pi_params(v_ph)

Set actor trainable parameters

Parameters: v_ph – flat vector of new actor parameter values
Returns: None
update(bs, ba, br, train_critic_iters, backtrack_iters, backtrack_coeff)

Update TRPO

Returns: None

Default Hyper-parameters

rlzoo.algorithms.trpo.default.atari(env, default_seed=True)
rlzoo.algorithms.trpo.default.box2d(env, default_seed=True)
rlzoo.algorithms.trpo.default.classic_control(env, default_seed=True)
rlzoo.algorithms.trpo.default.dm_control(env, default_seed=True)
rlzoo.algorithms.trpo.default.mujoco(env, default_seed=True)
rlzoo.algorithms.trpo.default.rlbench(env, default_seed=True)
rlzoo.algorithms.trpo.default.robotics(env, default_seed=True)