@@ -1554,6 +1554,7 @@ class RbLLaMAModel {
1554
1554
rb_define_method (rb_cLLaMAModel, " token_is_control?" , RUBY_METHOD_FUNC (_llama_model_token_is_control), 1 );
1555
1555
rb_define_method (rb_cLLaMAModel, " has_encoder?" , RUBY_METHOD_FUNC (_llama_model_has_encoder), 0 );
1556
1556
rb_define_method (rb_cLLaMAModel, " decoder_start_token" , RUBY_METHOD_FUNC (_llama_model_decoder_start_token), 0 );
1557
+ rb_define_method (rb_cLLaMAModel, " detokenize" , RUBY_METHOD_FUNC (_llama_model_detokenize), -1 );
1557
1558
}
1558
1559
1559
1560
private:
@@ -1906,6 +1907,48 @@ class RbLLaMAModel {
1906
1907
LLaMAModelWrapper* ptr = get_llama_model (self);
1907
1908
return INT2NUM (llama_model_decoder_start_token (ptr->model ));
1908
1909
}
1910
+
1911
+ static VALUE _llama_model_detokenize (int argc, VALUE* argv, VALUE self) {
1912
+ VALUE kw_args = Qnil;
1913
+ ID kw_table[2 ] = { rb_intern (" remove_special" ), rb_intern (" unparse_special" ) };
1914
+ VALUE kw_values[2 ] = { Qundef, Qundef };
1915
+ VALUE tokens_ = Qnil;
1916
+ rb_scan_args (argc, argv, " 1:" , &tokens_, &kw_args);
1917
+ rb_get_kwargs (kw_args, kw_table, 0 , 2 , kw_values);
1918
+
1919
+ if (!RB_TYPE_P (tokens_, T_ARRAY)) {
1920
+ rb_raise (rb_eArgError, " tokens must be an array" );
1921
+ return Qnil;
1922
+ }
1923
+
1924
+ const int32_t n_tokens = RARRAY_LEN (tokens_);
1925
+ llama_token* tokens = ALLOCA_N (llama_token, n_tokens);
1926
+ for (int32_t i = 0 ; i < n_tokens; i++) {
1927
+ tokens[i] = NUM2INT (rb_ary_entry (tokens_, i));
1928
+ }
1929
+
1930
+ std::string text;
1931
+ text.resize (std::max (text.capacity (), static_cast <unsigned long >(n_tokens)));
1932
+ const int32_t text_len_max = text.size ();
1933
+
1934
+ bool remove_special = kw_values[0 ] != Qundef ? RTEST (kw_values[0 ]) : false ;
1935
+ bool unparse_special = kw_values[1 ] != Qundef ? RTEST (kw_values[1 ]) : false ;
1936
+
1937
+ LLaMAModelWrapper* ptr = get_llama_model (self);
1938
+ std::string result;
1939
+ int32_t n_chars = llama_detokenize (ptr->model , tokens, n_tokens, &text[0 ], text_len_max, remove_special, unparse_special);
1940
+ if (n_chars < 0 ) {
1941
+ text.resize (-n_chars);
1942
+ n_chars = llama_detokenize (ptr->model , tokens, n_tokens, &text[0 ], text_len_max, remove_special, unparse_special);
1943
+ if (n_chars <= text.size ()) {
1944
+ rb_raise (rb_eRuntimeError, " Failed to detokenize" );
1945
+ return Qnil;
1946
+ }
1947
+ }
1948
+
1949
+ text.resize (n_chars);
1950
+ return rb_utf8_str_new_cstr (text.c_str ());
1951
+ }
1909
1952
};
1910
1953
1911
1954
const rb_data_type_t RbLLaMAModel::llama_model_type = {
0 commit comments