iq.py
"""
Copyright 2022 Div Garg. All rights reserved.
Standalone IQ-Learn algorithm. See LICENSE for licensing terms.
"""
import torch
import torch.nn.functional as F

# Full IQ-Learn objective with other divergences and options
def iq_loss(agent, current_Q, current_v, next_v, batch):
    args = agent.args
    gamma = agent.gamma
    obs, next_obs, action, env_reward, done, is_expert = batch

    loss_dict = {}
    # keep track of the value of initial states
    v0 = agent.getV(obs[is_expert.squeeze(1), ...]).mean()
    loss_dict['v0'] = v0.item()

    # calculate 1st term of the IQ loss
    # -E_(ρ_expert)[Q(s, a) - γV(s')]
    y = (1 - done) * gamma * next_v
    reward = (current_Q - y)[is_expert]

    with torch.no_grad():
        # Use different divergence functions (for χ2 divergence we instead add a third, Bellman-error-like term)
        if args.method.div == "hellinger":
            phi_grad = 1 / (1 + reward)**2
        elif args.method.div == "kl":
            # original dual form for KL divergence (suboptimal)
            phi_grad = torch.exp(-reward - 1)
        elif args.method.div == "kl2":
            # biased dual form for KL divergence
            phi_grad = F.softmax(-reward, dim=0) * reward.shape[0]
        elif args.method.div == "kl_fix":
            # our proposed unbiased form for fixing KL divergence
            phi_grad = torch.exp(-reward)
        elif args.method.div == "js":
            # Jensen-Shannon divergence
            phi_grad = torch.exp(-reward) / (2 - torch.exp(-reward))
        else:
            phi_grad = 1
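    # Note: phi_grad is detached (computed under no_grad), so it acts as a
    # fixed per-sample weight. The surrogate -(phi_grad * reward) below then
    # has the same gradient as -phi(reward) for a concave phi whose derivative
    # phi' is what each branch above evaluates.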
    loss = -(phi_grad * reward).mean()
    loss_dict['softq_loss'] = loss.item()

    # calculate 2nd term of the IQ loss; we show different sampling strategies

    if args.method.loss == "value_expert":
        # sample using only expert states (works offline)
        # E_(ρ_expert)[V(s) - γV(s')]
        value_loss = (current_v - y)[is_expert].mean()
        loss += value_loss
        loss_dict['value_loss'] = value_loss.item()

    elif args.method.loss == "value":
        # sample using both expert and policy states (works online)
        # E_(ρ)[V(s) - γV(s')]
        value_loss = (current_v - y).mean()
        loss += value_loss
        loss_dict['value_loss'] = value_loss.item()

    elif args.method.loss == "v0":
        # alternative sampling using only initial states (works offline but is usually worse than the `value_expert` strategy)
        # (1-γ)E_(ρ0)[V(s0)]
        v0_loss = (1 - gamma) * v0
        loss += v0_loss
        loss_dict['v0_loss'] = v0_loss.item()

    else:
        raise ValueError(f'This sampling method is not implemented: {args.method.loss}')

    if args.method.grad_pen:
        # add a gradient penalty to the loss (Wasserstein-1 metric)
        gp_loss = agent.critic_net.grad_pen(
            obs[is_expert.squeeze(1), ...],
            action[is_expert.squeeze(1), ...],
            obs[~is_expert.squeeze(1), ...],
            action[~is_expert.squeeze(1), ...],
            args.method.lambda_gp)
        loss_dict['gp_loss'] = gp_loss.item()
        loss += gp_loss

    if args.method.div == "chi" or args.method.chi:  # TODO: Deprecate the method.chi argument in favor of method.div
        # Use χ2 divergence (calculate the regularization term for the IQ loss using expert states only; works offline)
        y = (1 - done) * gamma * next_v

        reward = current_Q - y
        chi2_loss = 1 / (4 * args.method.alpha) * (reward**2)[is_expert].mean()
        loss += chi2_loss
        loss_dict['chi2_loss'] = chi2_loss.item()
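        # Together with the 1st term (where phi_grad == 1 in this "chi" branch),
        # this corresponds to the concave phi(x) = x - x**2 / (4 * alpha),
        # i.e. maximizing E_(ρ_expert)[phi(Q(s, a) - γV(s'))].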

    if args.method.regularize:
        # Use χ2 divergence (calculate the regularization term for the IQ loss using both expert and policy states; works online)
        y = (1 - done) * gamma * next_v

        reward = current_Q - y
        chi2_loss = 1 / (4 * args.method.alpha) * (reward**2).mean()
        loss += chi2_loss
        loss_dict['regularize_loss'] = chi2_loss.item()

    loss_dict['total_loss'] = loss.item()

    return loss, loss_dict
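

# ---------------------------------------------------------------------------
# Hypothetical usage sketch: fakes the `agent` interface used above (`args`,
# `gamma`, `getV`) with `types.SimpleNamespace`, tiny random critic heads, and
# random tensors, just to show how `iq_loss` is called. All shapes and
# hyperparameter values here are illustrative assumptions, not values from the
# IQ-Learn paper or training configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    batch_size, obs_dim, act_dim = 8, 4, 2

    # Tiny stand-in critic: V(s) and Q(s, a) from small random linear heads.
    v_head = torch.nn.Linear(obs_dim, 1)
    q_head = torch.nn.Linear(obs_dim + act_dim, 1)

    agent = SimpleNamespace(
        gamma=0.99,
        args=SimpleNamespace(method=SimpleNamespace(
            div="chi", chi=False, loss="value",
            grad_pen=False, lambda_gp=10.0,
            alpha=0.5, regularize=False)),
        getV=lambda o: v_head(o),
    )

    obs = torch.randn(batch_size, obs_dim)
    next_obs = torch.randn(batch_size, obs_dim)
    action = torch.randn(batch_size, act_dim)
    env_reward = torch.zeros(batch_size, 1)
    done = torch.zeros(batch_size, 1)
    # Treat the first half of the batch as expert transitions.
    is_expert = torch.cat([torch.ones(batch_size // 2, 1),
                           torch.zeros(batch_size // 2, 1)]).bool()

    current_Q = q_head(torch.cat([obs, action], dim=1))
    current_v = v_head(obs)
    next_v = v_head(next_obs)

    batch = (obs, next_obs, action, env_reward, done, is_expert)
    loss, loss_dict = iq_loss(agent, current_Q, current_v, next_v, batch)
    print({k: round(v, 4) for k, v in loss_dict.items()})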