Commit 47c38a7

Committed Jul 26, 2023
Merged generate.py and full.py and removed generate.py
1 parent 03f5d5e · commit 47c38a7

8 files changed (+101 −189 lines)

‎README.md

+4 −4

@@ -77,7 +77,7 @@ To generate text predictions, you need to download the model weights. **If you d
 Run inference:
 
 ```bash
-python generate.py --prompt "Hello, my name is"
+python generate/full.py --prompt "Hello, my name is"
 ```
 
 This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
@@ -86,14 +86,14 @@ This will run the 7B model and require ~26 GB of GPU memory (A100 GPU).
 
 ### Run Lit-LLaMA on consumer devices
 
-On GPUs with `bfloat16` support, the `generate.py` script will automatically convert the weights and consume about ~14 GB.
+On GPUs with `bfloat16` support, the `full.py` script will automatically convert the weights and consume about ~14 GB.
 For GPUs with less memory, or ones that don't support `bfloat16`, enable quantization (`--quantize llm.int8`):
 
 ```bash
-python generate.py --quantize llm.int8 --prompt "Hello, my name is"
+python generate/full.py --quantize llm.int8 --prompt "Hello, my name is"
 ```
 
-See `python generate.py --help` for more options.
+See `python generate/full.py --help` for more options.
 
 You can also use GPTQ-style int4 quantization, but this needs conversions of the weights first:

‎generate.py

−170
This file was deleted.

‎generate/adapter.py

+1 −1

@@ -11,7 +11,7 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from generate import generate
+from generate.generate_utils import generate
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA
 from lit_llama.utils import lazy_load, llama_model_lookup, quantization

‎generate/adapter_v2.py

+1 −1

@@ -11,7 +11,7 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from generate import generate
+from generate.generate_utils import generate
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA
 from lit_llama.utils import lazy_load, llama_model_lookup, quantization

‎generate/full.py

+14 −9

@@ -12,10 +12,9 @@
 sys.path.append(str(wd))
 
 from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import quantization
+from lit_llama.utils import quantization, lazy_load, llama_model_lookup
 from scripts.prepare_alpaca import generate_prompt
-from generate import generate
-
+from generate.generate_utils import generate
 
 def main(
     prompt: str = "Hello, my name is",
@@ -28,6 +27,7 @@ def main(
     tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
     model_size: str = "7B",
     quantize: Optional[str] = None,
+    instruction_tuning: Optional[bool] = False
 ) -> None:
     """Generates text samples based on a pre-trained LLaMA model and tokenizer.
 
@@ -44,6 +44,7 @@ def main(
         quantize: Whether to quantize the model and using which method:
             ``"llm.int8"``: LLM.int8() mode,
            ``"gptq.int4"``: GPTQ 4-bit mode.
+        instruction_tuning: Whether to regenerate the sample in instruction-tuning format.
     """
     if not checkpoint_path:
         checkpoint_path = Path(f"checkpoints/lit-llama/{model_size}/lit-llama.pth")
@@ -56,19 +57,23 @@ def main(
     print("Loading model ...", file=sys.stderr)
     t0 = time.time()
 
-    with fabric.init_module(empty_init=True), quantization(mode=quantize):
-        model = LLaMA.from_name(model_size)
+    with lazy_load(checkpoint_path) as checkpoint:
+        name = llama_model_lookup(checkpoint)
+
+        with fabric.init_module(empty_init=True), quantization(mode=quantize):
+            model = LLaMA.from_name(name)
 
-    checkpoint = torch.load(checkpoint_path)
-    model.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint)
     print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr)
 
     model.eval()
     model = fabric.setup(model)
 
     tokenizer = Tokenizer(tokenizer_path)
-    sample = {"instruction": prompt, "input": input}
-    prompt = generate_prompt(sample)
+
+    if instruction_tuning:
+        sample = {"instruction": prompt, "input": input}
+        prompt = generate_prompt(sample)
     encoded = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
     prompt_length = encoded.size(0)
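Taken together, the `generate/full.py` changes replace the `--model_size`-driven instantiation with checkpoint introspection: the checkpoint is opened with `lazy_load`, `llama_model_lookup` infers the model variant from the weights themselves, and `load_state_dict` runs while the checkpoint is still lazily mapped. Below is a minimal sketch of the resulting loading path, reconstructed from the diff above; the checkpoint path, `Fabric` setup, and `quantize = None` are illustrative stand-ins for values the real script takes from its CLI options.

```python
# Minimal sketch of the post-refactor loading path in generate/full.py
# (reconstructed from the diff above; paths and Fabric setup are illustrative).
from pathlib import Path

import lightning as L

from lit_llama import LLaMA
from lit_llama.utils import lazy_load, llama_model_lookup, quantization

fabric = L.Fabric(devices=1)
quantize = None  # or "llm.int8" / "gptq.int4", as passed via --quantize
checkpoint_path = Path("checkpoints/lit-llama/7B/lit-llama.pth")

with lazy_load(checkpoint_path) as checkpoint:
    # Infer the model variant (e.g. "7B") from the checkpoint instead of --model_size.
    name = llama_model_lookup(checkpoint)

    with fabric.init_module(empty_init=True), quantization(mode=quantize):
        model = LLaMA.from_name(name)

    # Load weights while the checkpoint is still lazily mapped
    # (no eager torch.load of the full file up front).
    model.load_state_dict(checkpoint)

model.eval()
model = fabric.setup(model)
```

The practical effect is that the weights on disk, not the `--model_size` flag, decide which LLaMA configuration gets built, so a mismatched flag can no longer cause a state-dict loading error.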

‎generate/generate_utils.py

+77

@@ -0,0 +1,77 @@
+import lightning as L
+import torch
+from typing import Optional
+from lit_llama import LLaMA
+
+@torch.no_grad()
+def generate(
+    model: LLaMA,
+    idx: torch.Tensor,
+    max_new_tokens: int,
+    *,
+    max_seq_length: Optional[int] = None,
+    temperature: float = 1.0,
+    top_k: Optional[int] = None,
+    eos_id: Optional[int] = None,
+) -> torch.Tensor:
+    """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
+
+    The implementation of this function is modified from A. Karpathy's nanoGPT.
+
+    Args:
+        model: The model to use.
+        idx: Tensor of shape (T) with indices of the prompt sequence.
+        max_new_tokens: The number of new tokens to generate.
+        max_seq_length: The maximum sequence length allowed.
+        temperature: Scales the predicted logits by 1 / temperature
+        top_k: If specified, only sample among the tokens with the k highest probabilities
+        eos_id: If specified, stop generating any more token once the <eos> token is triggered
+    """
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    T = idx.size(0)
+    T_new = T + max_new_tokens
+    if max_seq_length is None:
+        max_seq_length = min(T_new, model.config.block_size)
+
+    device, dtype = idx.device, idx.dtype
+    # create an empty tensor of the expected final shape and fill in the current tokens
+    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty[:T] = idx
+    idx = empty
+    input_pos = torch.arange(0, T, device=device)
+
+    if idx.device.type == "xla":
+        import torch_xla.core.xla_model as xm
+
+        xm.mark_step()
+
+    # generate max_new_tokens tokens
+    for _ in range(max_new_tokens):
+        x = idx.index_select(0, input_pos).view(1, -1)
+
+        # forward
+        logits = model(x, max_seq_length, input_pos)
+        logits = logits[0, -1] / temperature
+
+        # optionally crop the logits to only the top k options
+        if top_k is not None:
+            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+            logits = torch.where(logits < v[[-1]], -float("Inf"), logits)
+
+        probs = torch.nn.functional.softmax(logits, dim=-1)
+        idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)
+
+        # advance
+        input_pos = input_pos[-1:] + 1
+
+        if idx.device.type == "xla":
+            xm.mark_step()
+
+        # concatenate the new generation
+        idx = idx.index_copy(0, input_pos, idx_next)
+
+        # if <eos> token is triggered, return the output (stop generation)
+        if idx_next == eos_id:
+            return idx[:input_pos]  # include the EOS token
+
+    return idx
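With the sampling loop factored out into `generate/generate_utils.py`, the `full.py`, `adapter.py`, `adapter_v2.py`, and `lora.py` scripts all import this one `generate` helper instead of the deleted top-level `generate.py`. Below is a hedged usage sketch mirroring how those scripts call it; the tokenizer path is illustrative, the weights are assumed to be loaded as in `generate/full.py`, and `tokenizer.decode` / `tokenizer.eos_id` are assumed `lit_llama.Tokenizer` attributes not shown in this diff.

```python
# Usage sketch for the shared generate() helper (a sketch, not the scripts' exact code).
# Assumes model weights are loaded as in generate/full.py; tokenizer.decode and
# tokenizer.eos_id are assumed lit_llama.Tokenizer attributes not shown in this diff.
from pathlib import Path

import lightning as L

from generate.generate_utils import generate
from lit_llama import LLaMA, Tokenizer

fabric = L.Fabric(devices=1)
with fabric.init_module(empty_init=True):
    model = LLaMA.from_name("7B")  # weights would be loaded here, as in generate/full.py
model.eval()
model = fabric.setup(model)

tokenizer = Tokenizer(Path("checkpoints/lit-llama/tokenizer.model"))
encoded = tokenizer.encode("Hello, my name is", bos=True, eos=False, device=fabric.device)

# Sample 50 new tokens with temperature / top-k sampling; generation stops early
# if eos_id is produced (see the <eos> check in the loop above).
output = generate(
    model,
    encoded,
    max_new_tokens=50,
    temperature=0.8,
    top_k=200,
    eos_id=tokenizer.eos_id,
)
print(tokenizer.decode(output))
```

Because `generate` only relies on the LLaMA forward signature `model(x, max_seq_length, input_pos)`, the adapter and LoRA variants can reuse it unchanged.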

‎generate/lora.py

+1 −1

@@ -11,10 +11,10 @@
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))
 
-from generate import generate
 from lit_llama import Tokenizer, LLaMA
 from lit_llama.lora import lora
 from lit_llama.utils import lazy_load, llama_model_lookup
+from generate.generate_utils import generate
 from scripts.prepare_alpaca import generate_prompt
 
 lora_r = 8

‎tests/test_generate.py

+3 −3

@@ -16,9 +16,9 @@
 def load_generate_script():
     sys.path.append(str(wd))
 
-    import generate as generate
+    from generate import full
 
-    return generate
+    return full
 
 
 def test_generate():
@@ -111,7 +111,7 @@ def init_module(self, empty_init):
 
 
 def test_cli():
-    cli_path = wd / "generate.py"
+    cli_path = wd / "generate/full.py"
     output = subprocess.check_output([sys.executable, cli_path, "-h"])
     output = str(output.decode())
     assert "Generates text samples" in output
