@@ -9,98 +9,54 @@ module LlamaCpp
 
   # Generates sentences following the given prompt for operation check.
   #
-  # @param context [LLaMACpp::Context] The context to use.
+  # @param context [LlamaCpp::LlamaContext] The context to use.
   # @param prompt [String] The prompt to start generation with.
   # @param n_predict [Integer] The number of tokens to predict.
-  # @param n_keep [Integer] The number of tokens to keep in the context.
-  # @param n_batch [Integer] The number of tokens to process in a batch.
-  # @param repeat_last_n [Integer] The number of tokens to consider for repetition penalty.
-  # @param repeat_penalty [Float] The repetition penalty.
-  # @param frequency [Float] The frequency penalty.
-  # @param presence [Float] The presence penalty.
-  # @param top_k [Integer] The number of tokens to consider for top-k sampling.
-  # @param top_p [Float] The probability threshold for nucleus sampling.
-  # @param tfs_z [Float] The z parameter for tail-free sampling.
-  # @param typical_p [Float] The probability for typical sampling.
-  # @param temperature [Float] The temperature for temperature sampling.
   # @return [String]
-  def generate(context, prompt, # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/ParameterLists, Metrics/PerceivedComplexity
-               n_predict: 128, n_keep: 10, n_batch: 512, repeat_last_n: 64,
-               repeat_penalty: 1.1, frequency: 0.0, presence: 0.0, top_k: 40,
-               top_p: 0.95, tfs_z: 1.0, typical_p: 1.0, temperature: 0.8)
-    raise ArgumentError, 'context must be an instance of LLaMACpp::Context' unless context.is_a?(LLaMACpp::Context)
+  def generate(context, prompt, n_predict: 128) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength
+    raise ArgumentError, 'context must be a LlamaContext' unless context.is_a?(LlamaCpp::LlamaContext)
     raise ArgumentError, 'prompt must be a String' unless prompt.is_a?(String)
 
-    spaced_prompt = " #{prompt}"
-    embd_input = context.model.tokenize(text: spaced_prompt, add_bos: true)
+    model = LlamaCpp.llama_get_model(context)
+    vocab = LlamaCpp.llama_model_get_vocab(model)
 
-    n_ctx = context.n_ctx
-    raise ArgumentError, "prompt is too long #{embd_input.size} tokens, maximum is #{n_ctx - 4}" if embd_input.size > n_ctx - 4
+    n_prompt = -LlamaCpp.llama_tokenize(vocab, prompt, [], 0, true, true)
 
-    last_n_tokens = [0] * n_ctx
+    prompt_tokens = []
+    raise 'Failed to tokenize the prompt' if LlamaCpp.llama_tokenize(vocab, prompt, prompt_tokens, n_prompt, true,
+                                                                     true).negative?
 
-    embd = []
-    n_consumed = 0
-    n_past = 0
-    n_remain = n_predict
-    n_vocab = context.model.n_vocab
+    ctx_params = LlamaCpp::LlamaContextParams.new
+    ctx_params.n_ctx = n_prompt + n_predict - 1
+    ctx_params.n_batch = n_prompt
+    ctx_params.no_perf = false
+
+    ctx = LlamaCpp.llama_init_from_model(model, ctx_params)
+
+    sparams = LlamaCpp::LlamaSamplerChainParams.new
+    sparams.no_perf = false
+    smpl = LlamaCpp.llama_sampler_chain_init(sparams)
+    LlamaCpp.llama_sampler_chain_add(smpl, LlamaCpp.llama_sampler_init_greedy)
+
+    batch = LlamaCpp.llama_batch_get_one(prompt_tokens)
+
+    n_pos = 0
     output = []
+    while n_pos + batch.n_tokens < n_prompt + n_predict
+      break if LlamaCpp.llama_decode(ctx, batch) != 0
+
+      n_pos += batch.n_tokens
+
+      new_token_id = LlamaCpp.llama_sampler_sample(smpl, ctx, -1)
+      break if llama_vocab_is_eog?(vocab, new_token_id)
+
+      buf = llama_token_to_piece(vocab, new_token_id, 0, true)
+      output << buf
 
-    while n_remain != 0
-      unless embd.empty?
-        if n_past + embd.size > n_ctx
-          n_left = n_past - n_keep
-          n_past = n_keep
-          embd.insert(0, last_n_tokens[(n_ctx - (n_left / 2) - embd.size)...-embd.size])
-        end
-
-        context.decode(LLaMACpp::Batch.get_one(tokens: embd, n_tokens: embd.size, pos_zero: n_past, seq_id: 0))
-      end
-
-      n_past += embd.size
-      embd.clear
-
-      if embd_input.size <= n_consumed
-        logits = context.logits
-        base_candidates = Array.new(n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: logits[i], p: 0.0) }
-        candidates = LLaMACpp::TokenDataArray.new(base_candidates)
-
-        # apply penalties
-        last_n_repeat = [last_n_tokens.size, repeat_last_n, n_ctx].min
-        context.sample_repetition_penalties(
-          candidates, last_n_tokens[-last_n_repeat..],
-          penalty_repeat: repeat_penalty, penalty_freq: frequency, penalty_present: presence
-        )
-
-        # temperature sampling
-        context.sample_top_k(candidates, k: top_k)
-        context.sample_tail_free(candidates, z: tfs_z)
-        context.sample_typical(candidates, prob: typical_p)
-        context.sample_top_p(candidates, prob: top_p)
-        context.sample_temp(candidates, temp: temperature)
-        id = context.sample_token(candidates)
-
-        last_n_tokens.shift
-        last_n_tokens.push(id)
-
-        embd.push(id)
-        n_remain -= 1
-      else
-        while embd_input.size > n_consumed
-          embd.push(embd_input[n_consumed])
-          last_n_tokens.shift
-          last_n_tokens.push(embd_input[n_consumed])
-          n_consumed += 1
-          break if embd.size >= n_batch
-        end
-      end
-
-      embd.each { |token| output << context.model.token_to_piece(token) }
-
-      break if !embd.empty? && embd[-1] == context.model.token_eos
+      batch = LlamaCpp.llama_batch_get_one([new_token_id])
     end
 
-    output.join.scrub('?').strip.delete_prefix(prompt).strip
+    output.join
   end
 end
 
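For anyone trying out the new API, here is a minimal usage sketch of the rewritten `generate`. It assumes the gem mirrors llama.cpp's C-style loader names (`llama_backend_init`, `llama_model_default_params`, `llama_model_load_from_file`), which do not appear in this diff, so check the generated bindings before copying it; the model path is a placeholder.

```ruby
require 'llama_cpp'

# Assumption: the gem exposes the llama.cpp backend/loader functions under their C names.
LlamaCpp.llama_backend_init
model_params = LlamaCpp.llama_model_default_params
model = LlamaCpp.llama_model_load_from_file('path/to/model.gguf', model_params)

# generate only uses the passed context to reach the model; per this diff it builds
# its own context (sized to n_prompt + n_predict - 1) and a greedy sampler chain internally.
ctx_params = LlamaCpp::LlamaContextParams.new
context = LlamaCpp.llama_init_from_model(model, ctx_params)

puts LlamaCpp.generate(context, 'Hello, my name is', n_predict: 32)
```

Note that the rewritten method no longer exposes sampling knobs (top_k, top_p, temperature, penalties); it always samples greedily via `llama_sampler_init_greedy`, so repeated calls with the same prompt should produce the same output.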