
Commit eefe9a8

feat: add lstrip and special keyword arguments to token_to_piece method
1 parent 055875c · commit eefe9a8


3 files changed: +24 -7 lines changed


ext/llama_cpp/dummy.rb

Lines changed: 3 additions & 1 deletion

@@ -466,8 +466,10 @@ def rope_freq_scale_train; end
 
     # Converts token to Ruby String.
     # @param token [Integer] The token to be converted.
+    # @param lstrip [Integer] The number allows the user to skip up to 'lstrip' leading spaces before copying.
+    # @param special [Boolean] The flag whether to allow rendering special tokens in the output.
     # @return [String]
-    def token_to_piece(token); end
+    def token_to_piece(token, lstrip: 0, special: false); end
 
     # Returns the logits.
     #
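
For reference, a minimal usage sketch of the extended Ruby API. The tokenize and token_to_piece calls follow the signatures shown in this diff; the model-loading lines are assumptions, and the exact ModelParams/Model constructor arguments may differ between gem versions.

require 'llama_cpp'

# Assumed construction; check the gem's README for the exact arguments.
params = LLaMACpp::ModelParams.new
model  = LLaMACpp::Model.new(model_path: 'path/to/model.gguf', params: params)

token_ids = model.tokenize(text: 'Hello world', add_bos: true)

# Previous behavior (still the default): keep leading spaces, hide special tokens.
pieces = token_ids.map { |t| model.token_to_piece(t) }

# New keyword arguments from this commit:
#   lstrip:  skip up to `lstrip` leading spaces before copying the piece
#   special: render special tokens (e.g. BOS/EOS) in the returned string
decoded = token_ids.map { |t| model.token_to_piece(t, lstrip: 0, special: true) }
puts decoded.join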

ext/llama_cpp/llama_cpp.cpp

Lines changed: 20 additions & 5 deletions

@@ -1530,7 +1530,7 @@ class RbLLaMAModel {
     rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
     rb_define_method(rb_cLLaMAModel, "n_layer", RUBY_METHOD_FUNC(_llama_model_get_model_n_layer), 0);
     rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
-    rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
+    rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), -1);
     rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
     rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
     rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
@@ -1691,18 +1691,33 @@ class RbLLaMAModel {
     return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
   }
 
-  static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
+  static VALUE _llama_model_token_to_piece(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("lstrip"), rb_intern("special") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    VALUE token_ = Qnil;
+    rb_scan_args(argc, argv, "1:", &token_, &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
     if (!RB_INTEGER_TYPE_P(token_)) {
       rb_raise(rb_eArgError, "token must be an integer");
       return Qnil;
     }
+    if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
+      rb_raise(rb_eArgError, "lstrip must be an integer");
+      return Qnil;
+    }
+
     const llama_token token = NUM2INT(token_);
+    const int32_t lstrip = kw_values[0] != Qundef ? NUM2INT(kw_values[0]) : 0;
+    const bool special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;
+
     LLaMAModelWrapper* ptr = get_llama_model(self);
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size(), false);
+    const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
     if (n_tokens < 0) {
       result.resize(-n_tokens);
-      const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size(), false);
+      const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
       if (check != -n_tokens) {
         rb_raise(rb_eRuntimeError, "failed to convert");
         return Qnil;
@@ -2788,7 +2803,7 @@ class RbLLaMAContext {
     ID kw_table[3] = { rb_intern("logits"), rb_intern("logits_guidance"), rb_intern("scale") };
     VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
       rb_raise(rb_eArgError, "logits must be an Array");
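
In the C extension, _llama_model_token_to_piece switches from a fixed one-argument method to rb_scan_args("1:") plus rb_get_kwargs, so lstrip and special arrive as optional keywords and are forwarded to llama_token_to_piece. The conversion keeps its two-pass buffer protocol: the first call writes into an 8-byte buffer and returns a negative value when the piece does not fit, after which the buffer is resized to that size and the call is repeated. (The last hunk separately changes the guidance-sampling keyword parsing from 0 required/3 optional to 3 required/0 optional, making logits, logits_guidance, and scale mandatory.) Below is a small Ruby sketch of the resize-and-retry protocol; native_token_to_piece is a hypothetical stand-in for the C call, not part of the gem.

# Hypothetical stand-in for llama_token_to_piece: returns the number of bytes
# written, or a negative value whose magnitude is the required buffer size.
# lstrip and special are accepted but ignored by this simplified stub.
def native_token_to_piece(token, buf, lstrip, special)
  piece = "token-#{token}" # placeholder piece text
  return -piece.bytesize if piece.bytesize > buf.bytesize
  buf[0, piece.bytesize] = piece
  piece.bytesize
end

def token_to_piece(token, lstrip: 0, special: false)
  buf = "\0".b * 8                           # first pass: small fixed buffer
  n = native_token_to_piece(token, buf, lstrip, special)
  if n.negative?
    buf = "\0".b * -n                        # grow to the size reported by the first pass
    check = native_token_to_piece(token, buf, lstrip, special)
    raise 'failed to convert' if check != -n # second pass must write exactly -n bytes
    n = -n
  end
  buf[0, n]
end

puts token_to_piece(12_345) # => "token-12345" (buffer grows from 8 to 11 bytes)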

sig/llama_cpp.rbs

Lines changed: 1 addition & 1 deletion

@@ -164,7 +164,7 @@ module LLaMACpp
     def n_embd: () -> Integer
     def n_layer: () -> Integer
     def rope_freq_scale_train: () -> Float
-    def token_to_piece: (Integer) -> String
+    def token_to_piece: (Integer, ?lstrip: Integer, ?special: bool) -> String
     def tokenize: (text: String, ?n_max_tokens: Integer, ?add_bos: bool, ?special: bool) -> Array[Integer]
     def desc: () -> String
     def size: () -> Integer
