@@ -1530,7 +1530,7 @@ class RbLLaMAModel {
    rb_define_method(rb_cLLaMAModel, "n_embd", RUBY_METHOD_FUNC(_llama_model_get_model_n_embd), 0);
    rb_define_method(rb_cLLaMAModel, "n_layer", RUBY_METHOD_FUNC(_llama_model_get_model_n_layer), 0);
    rb_define_method(rb_cLLaMAModel, "rope_freq_scale_train", RUBY_METHOD_FUNC(_llama_model_rope_freq_scale_train), 0);
-   rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), 1);
+   rb_define_method(rb_cLLaMAModel, "token_to_piece", RUBY_METHOD_FUNC(_llama_model_token_to_piece), -1);
    rb_define_method(rb_cLLaMAModel, "tokenize", RUBY_METHOD_FUNC(_llama_model_tokenize), -1);
    rb_define_method(rb_cLLaMAModel, "desc", RUBY_METHOD_FUNC(_llama_model_get_model_desc), 0);
    rb_define_method(rb_cLLaMAModel, "size", RUBY_METHOD_FUNC(_llama_model_get_model_size), 0);
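The only change in this first hunk is the arity passed to rb_define_method for token_to_piece, from 1 to -1. In the Ruby C API an arity of -1 means the bound function receives its arguments as (int argc, VALUE* argv, VALUE self) and parses them itself, which is what allows the lstrip/special keyword arguments introduced in the next hunk. A minimal sketch of that calling convention, with illustrative names that are not part of this change:

    #include <ruby.h>

    // Sketch of the arity -1 convention assumed by the registration change above.
    // With arity -1, Ruby passes the raw argument vector so the method can accept
    // optional positional and keyword arguments itself.
    static VALUE example_token_to_piece(int argc, VALUE* argv, VALUE self) {
      VALUE token = Qnil;
      VALUE kw_args = Qnil;
      rb_scan_args(argc, argv, "1:", &token, &kw_args); // one required argument plus optional keywords
      return token;
    }
    // Registered with an arity of -1, e.g.:
    // rb_define_method(klass, "token_to_piece", RUBY_METHOD_FUNC(example_token_to_piece), -1);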
@@ -1691,18 +1691,33 @@ class RbLLaMAModel {
    return DBL2NUM(llama_rope_freq_scale_train(ptr->model));
  }

- static VALUE _llama_model_token_to_piece(VALUE self, VALUE token_) {
+ static VALUE _llama_model_token_to_piece(int argc, VALUE* argv, VALUE self) {
+   VALUE kw_args = Qnil;
+   ID kw_table[2] = { rb_intern("lstrip"), rb_intern("special") };
+   VALUE kw_values[2] = { Qundef, Qundef };
+   VALUE token_ = Qnil;
+   rb_scan_args(argc, argv, "1:", &token_, &kw_args);
+   rb_get_kwargs(kw_args, kw_table, 0, 2, kw_values);
+
    if (!RB_INTEGER_TYPE_P(token_)) {
      rb_raise(rb_eArgError, "token must be an integer");
      return Qnil;
    }
+   if (kw_values[0] != Qundef && !RB_INTEGER_TYPE_P(kw_values[0])) {
+     rb_raise(rb_eArgError, "lstrip must be an integer");
+     return Qnil;
+   }
+
    const llama_token token = NUM2INT(token_);
+   const int32_t lstrip = kw_values[0] != Qundef ? NUM2INT(kw_values[0]) : 0;
+   const bool special = kw_values[1] != Qundef ? RTEST(kw_values[1]) : false;
+
    LLaMAModelWrapper* ptr = get_llama_model(self);
    std::vector<char> result(8, 0);
-   const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size(), false);
+   const int n_tokens = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
    if (n_tokens < 0) {
      result.resize(-n_tokens);
-     const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size(), false);
+     const int check = llama_token_to_piece(ptr->model, token, result.data(), result.size(), lstrip, special);
      if (check != -n_tokens) {
        rb_raise(rb_eRuntimeError, "failed to convert");
        return Qnil;
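The rewritten _llama_model_token_to_piece keeps the original grow-and-retry buffer handling but now forwards lstrip and special to llama_token_to_piece, matching the six-argument upstream signature this version of the binding targets. Both keywords are optional at the Ruby level: rb_get_kwargs with 0 required and 2 optional keywords leaves unset entries as Qundef, and the ternaries above fall back to lstrip = 0 and special = false. The retry works because llama_token_to_piece returns the number of bytes written, or a negative value whose magnitude is the buffer size it needs. A condensed sketch of that pattern, assuming the same llama.cpp API as the diff:

    #include <string>
    #include <vector>
    #include "llama.h"

    // Sketch only: the grow-and-retry idiom used in the hunk above.  Assumes
    // llama_token_to_piece returns bytes written, or -(required size) when the
    // buffer is too small.
    static std::string piece_for_token(const llama_model* model, llama_token token,
                                       int32_t lstrip, bool special) {
      std::vector<char> buf(8, 0);
      int n = llama_token_to_piece(model, token, buf.data(), static_cast<int32_t>(buf.size()), lstrip, special);
      if (n < 0) {
        buf.resize(-n); // the first call reported how many bytes are needed
        n = llama_token_to_piece(model, token, buf.data(), static_cast<int32_t>(buf.size()), lstrip, special);
      }
      return std::string(buf.data(), n > 0 ? static_cast<size_t>(n) : 0);
    }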
@@ -2788,7 +2803,7 @@ class RbLLaMAContext {
    ID kw_table[3] = { rb_intern("logits"), rb_intern("logits_guidance"), rb_intern("scale") };
    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
    rb_scan_args(argc, argv, ":", &kw_args);
-   rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
+   rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

    if (!RB_TYPE_P(kw_values[0], T_ARRAY)) {
      rb_raise(rb_eArgError, "logits must be an Array");
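The RbLLaMAContext hunk swaps the required/optional counts passed to rb_get_kwargs. Its third argument is the number of required keywords and its fourth the number of optional ones, so the old call (0, 3) treated logits, logits_guidance, and scale as all optional, while the new call (3, 0) makes all three mandatory and lets rb_get_kwargs raise ArgumentError itself when one is missing instead of failing later at the type checks. A minimal, self-contained sketch of the distinction; the method name is illustrative and the keyword names are copied from the hunk:

    #include <ruby.h>

    // Sketch of rb_get_kwargs(hash, table, required, optional, values) semantics.
    static VALUE kwargs_demo(int argc, VALUE* argv, VALUE self) {
      VALUE kw_args = Qnil;
      ID kw_table[3] = { rb_intern("logits"), rb_intern("logits_guidance"), rb_intern("scale") };
      VALUE kw_values[3] = { Qundef, Qundef, Qundef };
      rb_scan_args(argc, argv, ":", &kw_args);
      // (0, 3): every keyword optional; a missing one is left as Qundef and only
      //         trips the type checks further down.
      // (3, 0): every keyword required; rb_get_kwargs raises ArgumentError as soon
      //         as one is absent.
      rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
      return kw_values[0];
    }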