
Commit 4f83f52

Merge branch 'refs/heads/dev'

2 parents: bc7db93 + 15b5df7

File tree: 7 files changed, +64 -37 lines

exllamav2/architecture.py (+2)

@@ -402,6 +402,8 @@ class Params:
             self.mmp.mlp_bias = True
             self.mmp.norm = "layernorm"
 
+            self.standard_calib_noise = (5, 30)
+
         # Gemma
 
         if arch_string == "GemmaForCausalLM":

exllamav2/embedding.py (+32 -8)

@@ -186,16 +186,40 @@ def forward(
         if self.archparams.normalize_embeddings:
             hidden_states *= cfg.hidden_size ** 0.5
 
-        # Negative tokens during quantization are noise tokens
+        # Rows with negative tokens during quantization are noise tokens
 
         if kwargs.get("negative_ids_noise"):
-            mask = (input_ids < 0).unsqueeze(-1)
-            unmasked_values = hidden_states[~mask.expand_as(hidden_states)].float()
-            mean, std = unmasked_values.mean(), unmasked_values.std()
-            noise = torch.randn_like(hidden_states, dtype = torch.float)
-            noise = noise * std + mean
-            noise = noise.half()
-            hidden_states = torch.where(mask, noise, hidden_states)
+
+            n = 0
+            mean = torch.tensor([0.0], dtype = torch.float, device = hidden_states.device)
+            M2 = torch.tensor([0.0], dtype = torch.float, device = hidden_states.device)
+
+            for i in range(input_ids.shape[0]):
+                if input_ids[i][0] < 0:
+                    continue
+
+                er = hidden_states[i].float()
+                n += er.numel()
+                delta = er - mean
+                mean += delta.sum() / n
+                delta2 = er - mean
+                M2 += (delta * delta2).sum()
+                del er
+                del delta
+                del delta2
+
+            if n > 1:
+                std = torch.sqrt(M2 / (n - 1))
+
+            for i in range(input_ids.shape[0]):
+                if input_ids[i][0] >= 0:
+                    continue
+
+                er = hidden_states[i]
+                noise = torch.randn(er.size(), dtype = torch.float, device = hidden_states.device) * std + mean
+                er.copy_(noise.half())
+                del er
+                del noise
 
         # Move to pinned temp buffer for TP
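The new branch avoids materializing a full-size float copy of the masked hidden states: it accumulates the mean and standard deviation of the non-noise rows with a chunked Welford update, then overwrites each noise row in place. A minimal standalone sketch of that update, with illustrative names (not from the repo), checked against torch's own estimators:

import torch

# Chunked Welford update: each chunk folds into the running (n, mean, M2)
# triple, equivalent to Chan et al.'s parallel variance formula.
# Sketch only; streaming_mean_std is not a function from the repo.
def streaming_mean_std(chunks):
    n = 0
    mean = torch.tensor(0.0)
    M2 = torch.tensor(0.0)
    for chunk in chunks:
        x = chunk.float().flatten()
        n += x.numel()
        delta = x - mean           # deviation from the previous mean
        mean = mean + delta.sum() / n
        delta2 = x - mean          # deviation from the updated mean
        M2 = M2 + (delta * delta2).sum()
    std = torch.sqrt(M2 / (n - 1)) if n > 1 else torch.tensor(0.0)
    return mean, std

chunks = [torch.randn(4, 8) for _ in range(3)]
mean, std = streaming_mean_std(chunks)
flat = torch.cat([c.flatten() for c in chunks])
assert torch.allclose(mean, flat.mean(), atol = 1e-5)
assert torch.allclose(std, flat.std(), atol = 1e-5)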

exllamav2/exllamav2_ext/ext_rope.cpp (+21 -21)

@@ -58,50 +58,50 @@ void rope_
     );
 }
 
-long gen_mrope_pos_ids
+int64_t gen_mrope_pos_ids
 (
     torch::Tensor mrope_pos_ids,
     torch::Tensor ids,
     int merge_size,
-    const std::vector<std::tuple<long, long>> &spans,
-    const std::vector<std::tuple<long, long, long>> &grids
+    const std::vector<std::tuple<int64_t, int64_t>> &spans,
+    const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids
 )
 {
     int max_length = mrope_pos_ids.size(1);
     int in_length = ids.size(0);
 
-    long* in_ids = (long*) ids.data_ptr();
-    long* pos_ids = (long*) mrope_pos_ids.data_ptr();
+    int64_t* in_ids = (int64_t*) ids.data_ptr();
+    int64_t* pos_ids = (int64_t*) mrope_pos_ids.data_ptr();
 
-    long* out_t = pos_ids;
-    long* out_h = pos_ids + max_length;
-    long* out_w = pos_ids + 2 * max_length;
+    int64_t* out_t = pos_ids;
+    int64_t* out_h = pos_ids + max_length;
+    int64_t* out_w = pos_ids + 2 * max_length;
 
-    long base_t = 0;
-    long next_base_t = 0;
+    int64_t base_t = 0;
+    int64_t next_base_t = 0;
 
     for (int i = 0; i < max_length; ++i)
     {
        bool is_emb = false;
        if (i < in_length)
        {
-           long id = in_ids[i];
+           int64_t id = in_ids[i];
 
            for (int j = 0; j < spans.size(); ++j)
            {
-               long span_start = std::get<0>(spans[j]);
-               long span_end = std::get<1>(spans[j]);
-               long span = span_end - span_start;
+               int64_t span_start = std::get<0>(spans[j]);
+               int64_t span_end = std::get<1>(spans[j]);
+               int64_t span = span_end - span_start;
                if (id >= span_start && id < span_end)
                {
                    is_emb = true;
-                   long k = id - span_start;
-                   long grid_t = std::get<0>(grids[j]);
-                   long grid_h = std::get<1>(grids[j]) / (long)merge_size;
-                   long grid_w = std::get<2>(grids[j]) / (long)merge_size;
-                   long k_t = base_t + (k / grid_w / grid_h) % grid_t;
-                   long k_h = base_t + (k / grid_w) % grid_h;
-                   long k_w = base_t + k % grid_w;
+                   int64_t k = id - span_start;
+                   int64_t grid_t = std::get<0>(grids[j]);
+                   int64_t grid_h = std::get<1>(grids[j]) / (int64_t)merge_size;
+                   int64_t grid_w = std::get<2>(grids[j]) / (int64_t)merge_size;
+                   int64_t k_t = base_t + (k / grid_w / grid_h) % grid_t;
+                   int64_t k_h = base_t + (k / grid_w) % grid_h;
+                   int64_t k_w = base_t + k % grid_w;
                    *out_t++ = k_t;
                    *out_h++ = k_h;
                    *out_w++ = k_w;
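The blanket long -> int64_t rename reads as a portability fix: torch.long tensors are always 64-bit, but C++ long is only 32 bits under Windows' LLP64 model, so the old data_ptr() casts would have truncated there. An illustrative check of the width difference, via ctypes:

import ctypes

# Illustrative only: C 'long' is 32-bit on Windows (LLP64) but 64-bit on
# Linux/macOS (LP64); int64_t is 64-bit everywhere, matching torch.long.
print(ctypes.sizeof(ctypes.c_long))   # 4 on Windows, 8 on Linux/macOS
print(ctypes.sizeof(ctypes.c_int64))  # 8 on every platform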

exllamav2/exllamav2_ext/ext_rope.h (+3 -3)

@@ -11,11 +11,11 @@ void rope_
     bool neox_style
 );
 
-long gen_mrope_pos_ids
+int64_t gen_mrope_pos_ids
 (
     torch::Tensor mrope_pos_ids,
     torch::Tensor ids,
     int merge_size,
-    const std::vector<std::tuple<long, long>> &spans,
-    const std::vector<std::tuple<long, long, long>> &grids
+    const std::vector<std::tuple<int64_t, int64_t>> &spans,
+    const std::vector<std::tuple<int64_t, int64_t, int64_t>> &grids
 );

exllamav2/generator/dynamic.py (+4 -3)

@@ -2589,8 +2589,9 @@ def deallocate_pages(self):
             self.generator.all_pages[0].backup()
 
         for seq in self.sequences:
-            for page in seq.allocated_pages:
-                page.sub_ref()
-            seq.allocated_pages = []
+            if seq.allocated_pages is not None:
+                for page in seq.allocated_pages:
+                    page.sub_ref()
+                seq.allocated_pages = []
 
         self.generator.validate_cache()
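The guard covers sequences whose cache pages were never allocated, where allocated_pages is presumably still None rather than a list; the old loop would raise a TypeError on those. A hypothetical minimal model of the pattern (Page, Seq, and deallocate_pages below are illustrative stand-ins, not the repo's classes):

# Hypothetical model: a refcounted page pool where a sequence may never
# have had pages allocated (allocated_pages stays None).
class Page:
    def __init__(self):
        self.ref_count = 1
    def sub_ref(self):
        self.ref_count -= 1

class Seq:
    def __init__(self, pages = None):
        self.allocated_pages = pages  # None until pages are allocated

def deallocate_pages(sequences):
    for seq in sequences:
        if seq.allocated_pages is not None:  # the new guard
            for page in seq.allocated_pages:
                page.sub_ref()
            seq.allocated_pages = []

deallocate_pages([Seq(), Seq([Page(), Page()])])  # no TypeError on Seq()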

exllamav2/mrope.py (+1 -1)

@@ -36,7 +36,7 @@ def gen_mrope_embed(
 
     # Create 3D position IDs
 
-    ids = input_ids.squeeze(0)
+    ids = input_ids.squeeze(0).contiguous()
     mrope_pos_ids = torch.zeros((3, max_length), dtype = torch.long).contiguous()
     merge_size = 1 if not embeddings else embeddings[0].model.config.vision_spatial_merge_size
     spans = []
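The .contiguous() call likely matters because gen_mrope_pos_ids walks ids through a raw data_ptr() in the C++ extension, which assumes dense row-major storage; if input_ids arrives as a strided view, squeeze(0) preserves the odd strides. An illustrative sketch of the hazard:

import torch

# Illustrative only: a strided view shares storage with its parent, so code
# that walks the raw data pointer would read the wrong elements.
# .contiguous() copies the view into dense memory first.
base = torch.arange(12, dtype = torch.long).reshape(3, 4)
view = base[:, ::2]                        # strides (4, 2): not dense
print(view.is_contiguous())                # False
dense = view.contiguous()
print(dense.is_contiguous())               # True: safe to pass to data_ptr()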

exllamav2/version.py (+1 -1)

@@ -1 +1 @@
-__version__ = "0.2.5"
+__version__ = "0.2.6"
