Skip to content

Commit 506c6f4

Browse files
committed
feat: add detokenize method to Model
1 parent 2613efe commit 506c6f4

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

ext/llama_cpp/dummy.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,14 @@ def has_encoder?; end # rubocop:disable Naming/PredicateName
# The token id that must be fed to the decoder first to begin generating
# the output sequence of an encoder-decoder model.
#
# @return [Integer]
def decoder_start_token; end
# Converts the given token ids back into the text they encode.
#
# @param tokens [Array<Integer>] The token ids to convert.
# @param remove_special [Boolean] Whether removing BOS/EOS tokens from the result is allowed.
# @param unparse_special [Boolean] Whether special tokens are rendered in the output.
# @return [String]
def detokenize(tokens, remove_special: false, unparse_special: false); end
580588
end
581589

582590
# Class for model KV override.

ext/llama_cpp/llama_cpp.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,7 @@ class RbLLaMAModel {
15541554
rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
15551555
rb_define_method(rb_cLLaMAModel, "has_encoder?", RUBY_METHOD_FUNC(_llama_model_has_encoder), 0);
15561556
rb_define_method(rb_cLLaMAModel, "decoder_start_token", RUBY_METHOD_FUNC(_llama_model_decoder_start_token), 0);
1557+
rb_define_method(rb_cLLaMAModel, "detokenize", RUBY_METHOD_FUNC(_llama_model_detokenize), -1);
15571558
}
15581559

15591560
private:
@@ -1906,6 +1907,48 @@ class RbLLaMAModel {
19061907
LLaMAModelWrapper* ptr = get_llama_model(self);
19071908
return INT2NUM(llama_model_decoder_start_token(ptr->model));
19081909
}
1910+
1911+
// Converts an array of token ids back into a UTF-8 Ruby String.
//
// Ruby signature:
//   detokenize(tokens, remove_special: false, unparse_special: false) -> String
//
// @param tokens [Array<Integer>] token ids to convert.
// @param remove_special [Boolean] whether removing BOS/EOS tokens is allowed.
// @param unparse_special [Boolean] whether special tokens are rendered in the output.
// @return [String] the detokenized text.
// @raise [ArgumentError] if tokens is not an Array.
// @raise [RuntimeError] if detokenization fails even after growing the buffer.
static VALUE _llama_model_detokenize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[2] = { rb_intern("remove_special"), rb_intern("unparse_special") };
  VALUE kw_values[2] = { Qundef, Qundef };
  VALUE tokens_ = Qnil;
  rb_scan_args(argc, argv, "1:", &tokens_, &kw_args);
  rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);

  if (!RB_TYPE_P(tokens_, T_ARRAY)) {
    rb_raise(rb_eArgError, "tokens must be an array");
    return Qnil;
  }

  const int32_t n_tokens = static_cast<int32_t>(RARRAY_LEN(tokens_));
  llama_token* tokens = ALLOCA_N(llama_token, n_tokens);
  for (int32_t i = 0; i < n_tokens; i++) {
    tokens[i] = NUM2INT(rb_ary_entry(tokens_, i));
  }

  const bool remove_special = kw_values[0] != Qundef ? RTEST(kw_values[0]) : false;
  const bool unparse_special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;

  LLaMAModelWrapper* ptr = get_llama_model(self);

  // First attempt with a buffer at least as large as the token count.
  // llama_detokenize returns the negated required length when the buffer
  // is too small.
  std::string text;
  text.resize(std::max(text.capacity(), static_cast<size_t>(n_tokens)));
  int32_t n_chars = llama_detokenize(ptr->model, tokens, n_tokens, &text[0], static_cast<int32_t>(text.size()), remove_special, unparse_special);
  if (n_chars < 0) {
    // Grow to the exact required length and retry with the NEW size.
    // (Bug fix: the original retried with the stale pre-resize length,
    // guaranteeing the retry would also fail.)
    text.resize(static_cast<size_t>(-n_chars));
    n_chars = llama_detokenize(ptr->model, tokens, n_tokens, &text[0], static_cast<int32_t>(text.size()), remove_special, unparse_special);
    // Bug fix: the original raised when `n_chars <= text.size()`, i.e. on
    // SUCCESS. A second failure is a negative result or one that still
    // exceeds the buffer.
    if (n_chars < 0 || static_cast<size_t>(n_chars) > text.size()) {
      rb_raise(rb_eRuntimeError, "Failed to detokenize");
      return Qnil;
    }
  }

  text.resize(static_cast<size_t>(n_chars));
  // Use the explicit-length constructor so embedded NUL bytes in the
  // detokenized text are preserved (rb_utf8_str_new_cstr would truncate).
  return rb_utf8_str_new(text.data(), static_cast<long>(text.size()));
}
19091952
};
19101953

19111954
const rb_data_type_t RbLLaMAModel::llama_model_type = {

ext/llama_cpp/llama_cpp.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef LLAMA_CPP_RB_H
22
#define LLAMA_CPP_RB_H 1
33

4+
#include <algorithm>
45
#include <sstream>
56
#include <string>
67
#include <vector>

sig/llama_cpp.rbs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ module LLaMACpp
188188
def token_is_control?: (Integer) -> bool
189189
def has_encoder?: () -> bool
190190
def decoder_start_token: () -> Integer
191+
# Converts token ids back into text; keyword flags control special-token handling.
def detokenize: (Array[Integer], ?remove_special: bool, ?unparse_special: bool) -> String
191192
end
192193

193194
class Timings

0 commit comments

Comments
 (0)