import numpy as np
import torch

import utils


class ReplayBuffer(object):
    """Buffer to store environment transitions."""

    def __init__(self, obs_shape, action_shape, capacity, device, window=1):
        self.capacity = capacity
        self.device = device

        # the proprioceptive obs is stored as float32, pixels obs as uint8
        obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8

        self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype)
        self.actions = np.empty((capacity, *action_shape), dtype=np.float32)
        self.rewards = np.empty((capacity, 1), dtype=np.float32)
        # not_dones_no_max stores the done flag excluding terminations caused
        # by hitting the episode time limit
        self.not_dones = np.empty((capacity, 1), dtype=np.float32)
        self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32)

        self.window = window
        self.idx = 0
        self.last_save = 0
        self.full = False

    def __len__(self):
        return self.capacity if self.full else self.idx

    def add(self, obs, action, reward, next_obs, done, done_no_max):
        np.copyto(self.obses[self.idx], obs)
        np.copyto(self.actions[self.idx], action)
        np.copyto(self.rewards[self.idx], reward)
        np.copyto(self.next_obses[self.idx], next_obs)
        np.copyto(self.not_dones[self.idx], not done)
        np.copyto(self.not_dones_no_max[self.idx], not done_no_max)

        # advance the write pointer; the buffer is circular, so once it wraps
        # around it starts overwriting the oldest transitions
        self.idx = (self.idx + 1) % self.capacity
        self.full = self.full or self.idx == 0

    def add_batch(self, obs, action, reward, next_obs, done, done_no_max):
        # write a batch of `self.window` transitions at once; the input arrays
        # are expected to have `self.window` entries along their first axis
        next_index = self.idx + self.window
        if next_index >= self.capacity:
            # the batch wraps around the end of the buffer: fill the tail,
            # then write the remainder at the beginning
            self.full = True
            maximum_index = self.capacity - self.idx
            np.copyto(self.obses[self.idx:self.capacity], obs[:maximum_index])
            np.copyto(self.actions[self.idx:self.capacity], action[:maximum_index])
            np.copyto(self.rewards[self.idx:self.capacity], reward[:maximum_index])
            np.copyto(self.next_obses[self.idx:self.capacity], next_obs[:maximum_index])
            np.copyto(self.not_dones[self.idx:self.capacity], done[:maximum_index] <= 0)
            np.copyto(self.not_dones_no_max[self.idx:self.capacity], done_no_max[:maximum_index] <= 0)

            remain = self.window - maximum_index
            if remain > 0:
                np.copyto(self.obses[0:remain], obs[maximum_index:])
                np.copyto(self.actions[0:remain], action[maximum_index:])
                np.copyto(self.rewards[0:remain], reward[maximum_index:])
                np.copyto(self.next_obses[0:remain], next_obs[maximum_index:])
                np.copyto(self.not_dones[0:remain], done[maximum_index:] <= 0)
                np.copyto(self.not_dones_no_max[0:remain], done_no_max[maximum_index:] <= 0)
            self.idx = remain
        else:
            np.copyto(self.obses[self.idx:next_index], obs)
            np.copyto(self.actions[self.idx:next_index], action)
            np.copyto(self.rewards[self.idx:next_index], reward)
            np.copyto(self.next_obses[self.idx:next_index], next_obs)
            np.copyto(self.not_dones[self.idx:next_index], done <= 0)
            np.copyto(self.not_dones_no_max[self.idx:next_index], done_no_max <= 0)
            self.idx = next_index

    def relabel_with_predictor(self, predictor):
        # overwrite the stored rewards with the predictor's estimates,
        # processing the buffer in chunks of `batch_size` transitions;
        # note that only entries up to self.idx are revisited
        batch_size = 200
        total_iter = int(self.idx / batch_size)
        if self.idx > batch_size * total_iter:
            total_iter += 1

        for index in range(total_iter):
            last_index = (index + 1) * batch_size
            if (index + 1) * batch_size > self.idx:
                last_index = self.idx

            obses = self.obses[index * batch_size:last_index]
            actions = self.actions[index * batch_size:last_index]
            inputs = np.concatenate([obses, actions], axis=-1)

            pred_reward = predictor.r_hat_batch(inputs)
            self.rewards[index * batch_size:last_index] = pred_reward

    def sample(self, batch_size):
        idxs = np.random.randint(0,
                                 self.capacity if self.full else self.idx,
                                 size=batch_size)

        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
        not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], device=self.device)

        return obses, actions, rewards, next_obses, not_dones, not_dones_no_max

    def sample_state_ent(self, batch_size):
        # same as sample(), but additionally returns every stored observation
        # (full_obs) so the caller can compute a state-entropy estimate
        idxs = np.random.randint(0,
                                 self.capacity if self.full else self.idx,
                                 size=batch_size)

        obses = torch.as_tensor(self.obses[idxs], device=self.device).float()
        actions = torch.as_tensor(self.actions[idxs], device=self.device)
        rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
        next_obses = torch.as_tensor(self.next_obses[idxs], device=self.device).float()
        not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
        not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], device=self.device)

        if self.full:
            full_obs = self.obses
        else:
            full_obs = self.obses[: self.idx]
        full_obs = torch.as_tensor(full_obs, device=self.device)

        return obses, full_obs, actions, rewards, next_obses, not_dones, not_dones_no_max
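

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: it shows how the
    # buffer is typically filled and sampled. The observation/action shapes,
    # capacity, and random transitions below are arbitrary example values,
    # not assumptions about any particular environment.
    buffer = ReplayBuffer(obs_shape=(17,), action_shape=(6,),
                          capacity=1000, device='cpu')

    # fill the buffer with a few random transitions
    for _ in range(300):
        obs = np.random.randn(17).astype(np.float32)
        action = np.random.uniform(-1, 1, size=6).astype(np.float32)
        next_obs = np.random.randn(17).astype(np.float32)
        buffer.add(obs, action, reward=0.0, next_obs=next_obs,
                   done=False, done_no_max=False)

    # sample a training batch of tensors on the requested device
    obs_b, act_b, rew_b, next_obs_b, not_done_b, not_done_no_max_b = buffer.sample(32)
    print(obs_b.shape, act_b.shape, rew_b.shape)

    # relabel_with_predictor would be called with an object exposing an
    # r_hat_batch(inputs) method, e.g. a learned reward model (not shown here)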