-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathdataset.py
25 lines (20 loc) · 978 Bytes
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import numpy as np
from torch.utils import data
class Dataset(data.Dataset):
    """Flat, per-timestep view over a collection of trajectories.

    Every trajectory in ``D`` must expose three attributes:

    * ``observations`` -- a sequence whose last element is dropped before
      stacking (only ``observations[:-1]`` is used),
    * ``actions`` -- a sequence that is stacked in full,
    * ``T`` -- a number; the sum of ``T`` over all trajectories is the
      dataset length.

    The stacked observations and actions of all trajectories are concatenated
    along axis 0 (in the iteration order of ``D``) and cast to ``float32``.

    Args:
        D: Re-iterable collection of trajectory objects (it is iterated more
            than once during construction).
        logprobs, entropies, Vs, Qs, As: Per-timestep arrays, each of shape
            ``(len(self),)``; they are stored as given, without copying or
            casting.
            NOTE(review): presumably log-probabilities, entropies, state
            values, Q-values and advantages of the taken actions -- confirm
            against the caller.

    Raises:
        AssertionError: If the flattened observations/actions or any of the
            per-timestep arrays do not match the total number of timesteps.
            Raised explicitly (not via ``assert``) so the check is still
            performed when Python runs with ``-O``.
    """

    def __init__(self, D, logprobs, entropies, Vs, Qs, As):
        self.D = D
        # One row per timestep: the final observation of each trajectory is
        # dropped so that rows pair up with the stacked actions.
        self.observations = np.concatenate(
            [np.stack(traj.observations[:-1], 0) for traj in D], 0
        ).astype(np.float32)
        self.actions = np.concatenate(
            [np.stack(traj.actions, 0) for traj in D], 0
        ).astype(np.float32)
        self.logprobs = logprobs
        self.entropies = entropies
        self.Vs = Vs
        self.Qs = Qs
        self.As = As
        # The arrays above are a snapshot of D, so the length is fixed from
        # here on; compute it once instead of re-summing on every len() call.
        self._length = sum(traj.T for traj in D)
        self._check_consistency()

    def _check_consistency(self):
        """Verify that every stored array has exactly one entry per timestep."""
        n = len(self)
        for name in ('observations', 'actions'):
            rows = getattr(self, name).shape[0]
            if rows != n:
                raise AssertionError(
                    f'{name} has {rows} rows, expected {n} (sum of traj.T)'
                )
        for name in ('logprobs', 'entropies', 'Vs', 'Qs', 'As'):
            shape = getattr(self, name).shape
            if shape != (n,):
                raise AssertionError(
                    f'{name} has shape {tuple(shape)}, expected {(n,)}'
                )

    def __len__(self):
        """Return the total number of timesteps over all trajectories."""
        return self._length

    def __getitem__(self, i):
        """Return the ``i``-th timestep as a 7-tuple.

        The tuple order is ``(observation, action, logprob, entropy, V, Q, A)``.
        """
        return (
            self.observations[i],
            self.actions[i],
            self.logprobs[i],
            self.entropies[i],
            self.Vs[i],
            self.Qs[i],
            self.As[i],
        )