
Commit 7702f47

chore!: fix generate module function with new API
1 parent b88ec9f commit 7702f47

File tree

1 file changed: +36, -80 lines

lib/llama_cpp.rb

Lines changed: 36 additions & 80 deletions
@@ -9,98 +9,54 @@ module LlamaCpp
 
   # Generates sentences following the given prompt for operation check.
   #
-  # @param context [LLaMACpp::Context] The context to use.
+  # @param context [LlamaCpp::LlamaContext] The context to use.
   # @param prompt [String] The prompt to start generation with.
   # @param n_predict [Integer] The number of tokens to predict.
-  # @param n_keep [Integer] The number of tokens to keep in the context.
-  # @param n_batch [Integer] The number of tokens to process in a batch.
-  # @param repeat_last_n [Integer] The number of tokens to consider for repetition penalty.
-  # @param repeat_penalty [Float] The repetition penalty.
-  # @param frequency [Float] The frequency penalty.
-  # @param presence [Float] The presence penalty.
-  # @param top_k [Integer] The number of tokens to consider for top-k sampling.
-  # @param top_p [Float] The probability threshold for nucleus sampling.
-  # @param tfs_z [Float] The z parameter for tail-free sampling.
-  # @param typical_p [Float] The probability for typical sampling.
-  # @param temperature [Float] The temperature for temperature sampling.
   # @return [String]
-  def generate(context, prompt, # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/ParameterLists, Metrics/PerceivedComplexity
-               n_predict: 128, n_keep: 10, n_batch: 512, repeat_last_n: 64,
-               repeat_penalty: 1.1, frequency: 0.0, presence: 0.0, top_k: 40,
-               top_p: 0.95, tfs_z: 1.0, typical_p: 1.0, temperature: 0.8)
-    raise ArgumentError, 'context must be an instance of LLaMACpp::Context' unless context.is_a?(LLaMACpp::Context)
+  def generate(context, prompt, n_predict: 128) # rubocop:disable Metrics/AbcSize, Metrics/MethodLength
+    raise ArgumentError, 'context must be a LlamaContext' unless context.is_a?(LlamaCpp::LlamaContext)
     raise ArgumentError, 'prompt must be a String' unless prompt.is_a?(String)
 
-    spaced_prompt = " #{prompt}"
-    embd_input = context.model.tokenize(text: spaced_prompt, add_bos: true)
+    model = LlamaCpp.llama_get_model(context)
+    vocab = LlamaCpp.llama_model_get_vocab(model)
 
-    n_ctx = context.n_ctx
-    raise ArgumentError, "prompt is too long #{embd_input.size} tokens, maximum is #{n_ctx - 4}" if embd_input.size > n_ctx - 4
+    n_prompt = -LlamaCpp.llama_tokenize(vocab, prompt, [], 0, true, true)
 
-    last_n_tokens = [0] * n_ctx
+    prompt_tokens = []
+    raise 'Failed to tokenize the prompt' if LlamaCpp.llama_tokenize(vocab, prompt, prompt_tokens, n_prompt, true,
+                                                                     true).negative?
 
-    embd = []
-    n_consumed = 0
-    n_past = 0
-    n_remain = n_predict
-    n_vocab = context.model.n_vocab
+    ctx_params = LlamaCpp::LlamaContextParams.new
+    ctx_params.n_ctx = n_prompt + n_predict - 1
+    ctx_params.n_batch = n_prompt
+    ctx_params.no_perf = false
+
+    ctx = LlamaCpp.llama_init_from_model(model, ctx_params)
+
+    sparams = LlamaCpp::LlamaSamplerChainParams.new
+    sparams.no_perf = false
+    smpl = LlamaCpp.llama_sampler_chain_init(sparams)
+    LlamaCpp.llama_sampler_chain_add(smpl, LlamaCpp.llama_sampler_init_greedy)
+
+    batch = LlamaCpp.llama_batch_get_one(prompt_tokens)
+
+    n_pos = 0
     output = []
+    while n_pos + batch.n_tokens < n_prompt + n_predict
+      break if LlamaCpp.llama_decode(ctx, batch) != 0
+
+      n_pos += batch.n_tokens
+
+      new_token_id = LlamaCpp.llama_sampler_sample(smpl, ctx, -1)
+      break if llama_vocab_is_eog?(vocab, new_token_id)
+
+      buf = llama_token_to_piece(vocab, new_token_id, 0, true)
+      output << buf
 
-    while n_remain != 0
-      unless embd.empty?
-        if n_past + embd.size > n_ctx
-          n_left = n_past - n_keep
-          n_past = n_keep
-          embd.insert(0, last_n_tokens[(n_ctx - (n_left / 2) - embd.size)...-embd.size])
-        end
-
-        context.decode(LLaMACpp::Batch.get_one(tokens: embd, n_tokens: embd.size, pos_zero: n_past, seq_id: 0))
-      end
-
-      n_past += embd.size
-      embd.clear
-
-      if embd_input.size <= n_consumed
-        logits = context.logits
-        base_candidates = Array.new(n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: logits[i], p: 0.0) }
-        candidates = LLaMACpp::TokenDataArray.new(base_candidates)
-
-        # apply penalties
-        last_n_repeat = [last_n_tokens.size, repeat_last_n, n_ctx].min
-        context.sample_repetition_penalties(
-          candidates, last_n_tokens[-last_n_repeat..],
-          penalty_repeat: repeat_penalty, penalty_freq: frequency, penalty_present: presence
-        )
-
-        # temperature sampling
-        context.sample_top_k(candidates, k: top_k)
-        context.sample_tail_free(candidates, z: tfs_z)
-        context.sample_typical(candidates, prob: typical_p)
-        context.sample_top_p(candidates, prob: top_p)
-        context.sample_temp(candidates, temp: temperature)
-        id = context.sample_token(candidates)
-
-        last_n_tokens.shift
-        last_n_tokens.push(id)
-
-        embd.push(id)
-        n_remain -= 1
-      else
-        while embd_input.size > n_consumed
-          embd.push(embd_input[n_consumed])
-          last_n_tokens.shift
-          last_n_tokens.push(embd_input[n_consumed])
-          n_consumed += 1
-          break if embd.size >= n_batch
-        end
-      end
-
-      embd.each { |token| output << context.model.token_to_piece(token) }
-
-      break if !embd.empty? && embd[-1] == context.model.token_eos
+      batch = LlamaCpp.llama_batch_get_one([new_token_id])
     end
 
-    output.join.scrub('?').strip.delete_prefix(prompt).strip
+    output.join
   end
 end

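For reference, a minimal usage sketch of the reworked function. This assumes the gem also exposes LlamaCpp::LlamaModelParams, LlamaCpp.llama_model_load_from_file, LlamaCpp.llama_free, and LlamaCpp.llama_model_free, and that generate remains available as a module function; none of those appear in this diff, so treat them as assumptions rather than the gem's confirmed API.

require 'llama_cpp'

# Load a model and build a context to hand to generate.
# Only generate, LlamaCpp::LlamaContextParams, and llama_init_from_model
# appear in this diff; the loader and cleanup calls are assumed.
model_params = LlamaCpp::LlamaModelParams.new
model = LlamaCpp.llama_model_load_from_file('path/to/model.gguf', model_params)

ctx_params = LlamaCpp::LlamaContextParams.new
context = LlamaCpp.llama_init_from_model(model, ctx_params)

puts LlamaCpp.generate(context, 'Hello, my name is', n_predict: 32)

# Free native resources when done (assumed API).
LlamaCpp.llama_free(context)
LlamaCpp.llama_model_free(model)

Note that the reworked generate uses the passed context only to reach the underlying model: it builds its own context, greedy sampler chain, and batches internally, and returns the decoded token pieces joined into a single string.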