
Commit bec8398

compress: format

1 parent b9a32f4

1 file changed: +33 -27 lines

examples/compress/compress.cpp (33 additions, 27 deletions)
@@ -78,7 +78,8 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
                 break;
             }
         }
-        if(match<0){
+        if (match < 0)
+        {
             LOG_ERR("\n couldn't match %s", llama_token_to_piece(ctx, inp[index]).c_str());
             exit(1);
         }
@@ -133,14 +134,13 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
         int block_size = (bit_offset + PAD) / 8 - block_start;
         if (block_size >= 256)
         {
-            // TODO: handle more than 256 bytes of block data
+            // TODO: handle more than 256 bytes of block data
             // (maybe allow multiple blocks in a row)
             LOG_ERR("Block too big %d >= 256", block_size);
             exit(-1);
         }
         sample_ids_bitpacked[block_start + 1] = block_size & 0xff;

-
         // put last bytes
         if (PAD)
         {
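For context on the 256-byte guard in this hunk: the byte at block_start + 1 is the block's own length field, so a single block can only describe up to 255 bytes of payload. A minimal sketch of that framing, using hypothetical names (append_block, tag, payload) that are not from the commit:

```cpp
#include <cstdint>
#include <vector>

// Sketch of the implied block layout: [tag byte][size byte][payload...].
// Because the size field is one byte, payloads of 256+ bytes cannot be
// represented -- hence the LOG_ERR/exit guard in the hunk above.
static bool append_block(std::vector<uint8_t> &out, uint8_t tag,
                         const std::vector<uint8_t> &payload) {
    if (payload.size() >= 256) {
        return false; // a real fix would split the data across blocks
    }
    out.push_back(tag);                              // block_start
    out.push_back((uint8_t)(payload.size() & 0xff)); // block_start + 1
    out.insert(out.end(), payload.begin(), payload.end());
    return true;
}
```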
@@ -212,7 +212,7 @@ std::vector<uint8_t> encode(llama_context *ctx, std::vector<llama_token> inp, gp
         int block_size = (bit_offset + PAD) / 8 - block_start;
         // endianness: big endian
         sample_ids_bitpacked[block_start + 1] = block_size & 0xff;
-        total_pad+=PAD;
+        total_pad += PAD;
     }
     llama_batch_free(batch);
     return sample_ids_bitpacked;
@@ -330,7 +330,7 @@ std::vector<llama_token> decode(llama_context *ctx, gpt_sampler *smpl, std::vect

         auto &cur_p = smpl->cur_p; // initialized by set_logits
         llama_sampler_apply(smpl->chain, &cur_p);
-
+
         auto token_id = cur_p.data[sample_id].id;
         out.push_back(token_id);
         if (!inp.size() || token_id == inp[num_raw_tokens_header + index])
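This hunk is the core of the decoder: encoder and decoder run the same model with the same sampler chain, so cur_p is ordered identically on both sides and the stored sample_id is simply a rank into that candidate list. A sketch of the symmetric encode-side lookup (a hypothetical helper, not code from this commit), which also explains the `match < 0` failure handled in the first hunk:

```cpp
// Find the rank of the token that actually occurred inside the sampler's
// candidate array. The decoder inverts this with cur_p.data[rank].id.
// A return of -1 means the sampler chain filtered the token out, which
// is the "couldn't match" failure handled at the top of encode().
static int find_rank(const llama_token_data_array &cur_p, llama_token actual) {
    for (size_t i = 0; i < cur_p.size; i++) {
        if (cur_p.data[i].id == actual) {
            return (int)i; // predictable text yields small, cheap ranks
        }
    }
    return -1;
}
```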
@@ -482,7 +482,7 @@ int main(int argc, char **argv)
     params.sparams.top_p = 1;
     params.sparams.top_k = -1;
     // Avoid temp=0 because greedy sampling breaks stuff
-    params.sparams.temp = 1.;
+    params.sparams.temp = 1.;

     gpt_init();

@@ -544,38 +544,43 @@ int main(int argc, char **argv)
         auto t_enc_end = ggml_time_us();

         LOG("\n");
-        if(!params.no_perf){
+        if (!params.no_perf)
+        {
             LOG("\nInput: %d characters (%d tokens)", params.prompt.length(), inp.size());

             float compressed_bits_per_token = 8 * (float)sample_ids_bitpacked.size() / (float)inp.size();
             float compressed_bits_per_char = 8 * (float)sample_ids_bitpacked.size() / (float)params.prompt.length();

             LOG("\n%d compressed bytes,(%04f bits per token, %04f bits per character)\n", (int)sample_ids_bitpacked.size(), compressed_bits_per_token, compressed_bits_per_char);
-            LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad/(float)params.prompt.length());
-            LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token-total_pad/(float)inp.size()),exp2(compressed_bits_per_token));
+            LOG("\n%d padding bits, (%04f bits per character without padding)", total_pad, compressed_bits_per_char - total_pad / (float)params.prompt.length());
+            LOG("\nPPL (over)estimation: %04f (%04f with padding)", exp2(compressed_bits_per_token - total_pad / (float)inp.size()), exp2(compressed_bits_per_token));
         }
-        //maybe this needs to be changed
-        if(params.out_file != "imatrix.dat"){
+        // maybe this needs to be changed
+        if (params.out_file != "imatrix.dat")
+        {
             // dump uint8array to bin file
             std::ofstream ofs(params.out_file.c_str(), std::ios::binary);
-            ofs.write((char*)&sample_ids_bitpacked[0], sample_ids_bitpacked.size());
+            ofs.write((char *)&sample_ids_bitpacked[0], sample_ids_bitpacked.size());
             ofs.close();
-        }else{
+        }
+        else
+        {
             LOG("\n------------\n");
-            //print as hex to stdout
-            for (int i = 0; i < sample_ids_bitpacked.size(); i++){
+            // print as hex to stdout
+            for (int i = 0; i < sample_ids_bitpacked.size(); i++)
+            {
                 LOG("%02X ", sample_ids_bitpacked[i]);
             }
         }
-
     }
     else if (params.compress_mode == 2)
     {
-        //decompress mode
-        // load sample_ids_bitpacked from params.prompt_file
+        // decompress mode
+        // load sample_ids_bitpacked from params.prompt_file
         std::ifstream ifs(params.prompt_file.c_str(), std::ios::binary);

-        if (!ifs) {
+        if (!ifs)
+        {
             LOG_ERR("%s: failed to open file\n", __func__);
             return -1;
         }
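A note on the PPL line in this hunk: an ideal coder driven by the model spends -log2 p(token) bits per token, so an average of H bits per token implies a model perplexity of 2^H. The padding bits are block-format overhead rather than model uncertainty, which is why subtracting total_pad / inp.size() before exponentiating gives the tighter figure. As a worked example with assumed numbers: 3.50 bits per token with 0.30 bits per token of padding gives exp2(3.50 - 0.30) = exp2(3.20) ≈ 9.19, versus exp2(3.50) ≈ 11.31 with padding included, hence the "(over)estimation" label.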
@@ -588,14 +593,16 @@ int main(int argc, char **argv)
         std::vector<uint8_t> sample_ids_bitpacked(fileSize);

         // Read the ifs into the vector
-        if (!ifs.read(reinterpret_cast<char*>(sample_ids_bitpacked.data()), fileSize)) {
+        if (!ifs.read(reinterpret_cast<char *>(sample_ids_bitpacked.data()), fileSize))
+        {
             LOG_ERR("%s: failed to read file\n", __func__);
             return -1;
         }
         ifs.close();

-        //Debug: print as hex
-        for (int i = 0; i < sample_ids_bitpacked.size(); i++){
+        // Debug: print as hex
+        for (int i = 0; i < sample_ids_bitpacked.size(); i++)
+        {
             LOG("%02X ", sample_ids_bitpacked[i]);
         }
         LOG("\n");
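The fileSize consumed by this hunk is computed before it; the commit doesn't show that code, but a standard pattern (assumed here, not taken from the file) is to open the stream positioned at the end and read back the position:

```cpp
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

// Assumed setup for the read in the hunk above: std::ios::ate opens the
// stream at the end, so tellg() yields the file size in bytes.
static std::vector<uint8_t> read_all_bytes(const std::string &path) {
    std::ifstream ifs(path, std::ios::binary | std::ios::ate);
    std::streamsize fileSize = ifs.tellg();
    ifs.seekg(0, std::ios::beg); // rewind before reading
    std::vector<uint8_t> buf((size_t)fileSize);
    ifs.read(reinterpret_cast<char *>(buf.data()), fileSize);
    return buf;
}
```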
@@ -612,23 +619,22 @@ int main(int argc, char **argv)

         std::vector<llama_token> out = decode(ctx, smpl, sample_ids_bitpacked);

-
         gpt_sampler_free(smpl);
         auto t_dec_end = ggml_time_us();

-        //maybe this needs to be changed
-        if(params.out_file != "imatrix.dat"){
+        // maybe this needs to be changed
+        if (params.out_file != "imatrix.dat")
+        {
             // dump as string to file
             std::string out_str = ::llama_detokenize(ctx, out);

             std::ofstream ofs(params.out_file.c_str(), std::ios::binary);
-            ofs.write((char*)&out_str[0], out_str.size());
+            ofs.write((char *)&out_str[0], out_str.size());
             ofs.close();
         }

         llama_free(ctx);
         llama_free_model(model);
-
     }

     llama_backend_free();
