Commit 10f8053

chore: add llm hybrid inference use case
1 parent 4ec4e65 commit 10f8053

1 file changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dfccd8e6",
   "metadata": {},
   "source": [
    "# Fine-Tuning GPT-2 on Encrypted Data with LoRA and Concrete ML\n",
    "\n",
    "In this notebook, we fine-tune a GPT-2 model using LoRA and Concrete ML."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eca73e44",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x779ae136e650>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import necessary libraries\n",
    "import math\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import Dataset\n",
    "from peft import LoraConfig, get_peft_model\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments\n",
    "\n",
    "from concrete.ml.torch.hybrid_model import HybridFHEModel\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c082411e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_print(prompt, model, tokenizer, seed=None, max_new_tokens=30):\n",
    "    \"\"\"\n",
    "    Generates text based on the provided prompt and prints both the prompt and the generated text.\n",
    "\n",
    "    Args:\n",
    "        prompt (str): The input prompt to generate text from.\n",
    "        model: The pre-trained language model.\n",
    "        tokenizer: The tokenizer associated with the model.\n",
    "        seed (int, optional): Seed for random number generators to ensure reproducibility.\n",
    "        max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 30.\n",
    "\n",
    "    Returns:\n",
    "        str: The generated text (response only, without the prompt).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        # Set the environment variable for CuBLAS deterministic behavior\n",
    "        os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "\n",
    "        # Set the random seed for reproducibility\n",
    "        if seed is not None:\n",
    "            random.seed(seed)\n",
    "            np.random.seed(seed)\n",
    "            torch.manual_seed(seed)\n",
    "            if torch.cuda.is_available():\n",
    "                torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "        # Encode the input prompt\n",
    "        inputs = tokenizer.encode_plus(prompt, return_tensors=\"pt\")\n",
    "\n",
    "        # Move inputs to the same device as the model\n",
    "        device = next(model.parameters()).device\n",
    "        inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "\n",
    "        # Generate text\n",
    "        with torch.no_grad():\n",
    "            output = model.generate(\n",
    "                input_ids=inputs[\"input_ids\"],\n",
    "                attention_mask=inputs[\"attention_mask\"],\n",
    "                max_new_tokens=max_new_tokens,\n",
    "                top_p=0.9,\n",
    "                temperature=0.6,\n",
    "                do_sample=True,\n",
    "                pad_token_id=tokenizer.eos_token_id,\n",
    "            )\n",
    "\n",
    "        # Get only the newly generated tokens\n",
    "        input_length = inputs[\"input_ids\"].shape[1]\n",
    "        generated_ids = output[0, input_length:]\n",
    "        generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()\n",
    "\n",
    "        # Print the prompt and generated text\n",
    "        print(f\"Prompt: {prompt}\")\n",
    "        print(f\"Response: {generated_text}\\n\")\n",
    "\n",
    "        return generated_text\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"Error in generation: {str(e)}\")\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8b965a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pre-trained GPT-2 model and tokenizer\n",
    "model_name = \"gpt2\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "\n",
    "# Ensure tokenizer has a pad token\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "# Freeze model weights\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
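  {
   "cell_type": "markdown",
   "id": "f0a1b2c3",
   "metadata": {},
   "source": [
    "The imports above bring in `LoraConfig` and `get_peft_model` from `peft`. Below is a minimal sketch of how LoRA adapters could be attached to the frozen GPT-2; the rank, scaling factor, and target modules are illustrative assumptions, not values from this notebook, and the cell is shown for reference only. The rest of the notebook continues with the frozen base model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1b2c3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: attach LoRA adapters to the frozen GPT-2.\n",
    "# The rank, alpha, and target modules below are illustrative assumptions.\n",
    "peft_config = LoraConfig(\n",
    "    r=8,  # assumed low-rank dimension\n",
    "    lora_alpha=32,  # assumed scaling factor\n",
    "    lora_dropout=0.05,\n",
    "    bias=\"none\",\n",
    "    task_type=\"CAUSAL_LM\",\n",
    "    target_modules=[\"c_attn\"],  # GPT-2's fused attention projection\n",
    ")\n",
    "peft_model = get_peft_model(model, peft_config)\n",
    "peft_model.print_trainable_parameters()"
   ]
  },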
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2337a6b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Programming is\n",
      "Response: a skill you need to learn to master.\n",
      "\n",
      "Learn to code\n",
      "\n",
      "There are a lot of different ways to learn programming.\n",
      "\n",
      "The\n",
      "\n"
     ]
    }
   ],
   "source": [
    "_ = generate_and_print(prompt=\"Programming is\", model=model, tokenizer=tokenizer, seed=SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a138d226",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "try:\n",
    "    from transformers import Conv1D as TransformerConv1D\n",
    "except ImportError:  # pragma: no cover\n",
    "    TransformerConv1D = None\n",
    "\n",
    "# Create a tuple of linear layer classes to check against\n",
    "LINEAR_LAYERS: tuple = (nn.Linear,)\n",
    "if TransformerConv1D is not None:\n",
    "    LINEAR_LAYERS = LINEAR_LAYERS + (TransformerConv1D,)\n",
    "\n",
    "# Collect the names of all linear-like modules to run remotely\n",
    "remote_names = []\n",
    "for name, module in model.named_modules():\n",
    "    if isinstance(module, LINEAR_LAYERS):\n",
    "        remote_names.append(name)"
   ]
  },
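  {
   "cell_type": "markdown",
   "id": "b2c3d4e5",
   "metadata": {},
   "source": [
    "As a quick sanity check (this cell is an illustrative addition), we can count the selected modules. For GPT-2, 49 layers are expected: 4 linear-like `Conv1D` modules per transformer block across 12 blocks, plus the LM head, matching the compilation progress bar below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d4e5f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sanity check: 49 modules are expected for GPT-2\n",
    "print(f\"{len(remote_names)} linear/Conv1D modules selected for remote execution\")\n",
    "print(remote_names[:3])"
   ]
  },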
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae2094a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the HybridFHEModel with the specified remote modules\n",
    "hybrid_model = HybridFHEModel(model, module_names=remote_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "20dfe2d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7db65de06c84890a25a4eb44d662bd1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Compiling FHE layers:   0%|          | 0/49 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "BLOCK_SIZE = 32\n",
    "# Prepare input data for calibration\n",
    "input_tensor = torch.randint(0, tokenizer.vocab_size, (256, BLOCK_SIZE), dtype=torch.long)\n",
    "\n",
    "# Calibrate and compile the model\n",
    "hybrid_model.compile_model(input_tensor, n_bits=8, use_dynamic_quantization=True)"
   ]
  },
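  {
   "cell_type": "markdown",
   "id": "c4d5e6f7",
   "metadata": {},
   "source": [
    "Before generating text, the compiled layers can be sanity-checked without a server round-trip. This is a minimal sketch, assuming the installed Concrete ML version supports the \"simulate\" value for `set_fhe_mode`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e6f7a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: run one calibration sample with FHE simulation,\n",
    "# assuming \"simulate\" is a supported mode in this Concrete ML version.\n",
    "hybrid_model.set_fhe_mode(\"simulate\")\n",
    "with torch.no_grad():\n",
    "    logits = hybrid_model.model(input_tensor[:1]).logits\n",
    "print(logits.shape)  # expected: (1, BLOCK_SIZE, vocab_size)"
   ]
  },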
  {
   "cell_type": "markdown",
   "id": "65d448c8",
   "metadata": {},
   "source": [
    "Note that our goal is to showcase the use of FHE for encrypted fine-tuning. The dataset consists of 68 examples and a total of 2,386 tokens, which is relatively small. Despite this limited size, which gives the model little material to learn from, fine-tuning still produces interesting results."
   ]
  },
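  {
   "cell_type": "markdown",
   "id": "e6f7a8b9",
   "metadata": {},
   "source": [
    "The dataset-loading and training cells are not part of this commit. As a hedged sketch of what the setup could look like with the already-imported `Dataset`, `Trainer`, and `TrainingArguments` (the `raw_texts` placeholder and every hyperparameter below are assumptions, and this is a standard non-FHE Hugging Face training loop shown only to illustrate the data format, using the `peft_model` from the LoRA sketch above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7a8b9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: raw_texts is a hypothetical placeholder for the\n",
    "# 68-example dataset; all hyperparameters are assumptions.\n",
    "raw_texts = [\"Programming is fun.\"]  # hypothetical placeholder\n",
    "\n",
    "def tokenize(batch):\n",
    "    tokens = tokenizer(\n",
    "        batch[\"text\"], truncation=True, max_length=BLOCK_SIZE, padding=\"max_length\"\n",
    "    )\n",
    "    # For causal LM fine-tuning, the labels are the input ids themselves\n",
    "    tokens[\"labels\"] = tokens[\"input_ids\"].copy()\n",
    "    return tokens\n",
    "\n",
    "train_dataset = Dataset.from_dict({\"text\": raw_texts}).map(\n",
    "    tokenize, batched=True, remove_columns=[\"text\"]\n",
    ")\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"checkpoints\",  # assumed output path\n",
    "    num_train_epochs=10,  # assumed\n",
    "    per_device_train_batch_size=4,  # assumed\n",
    "    learning_rate=2e-4,  # assumed\n",
    "    report_to=\"none\",\n",
    ")\n",
    "trainer = Trainer(model=peft_model, args=training_args, train_dataset=train_dataset)\n",
    "# trainer.train()  # left commented: training is outside the scope of this commit"
   ]
  },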
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e91ad0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set FHE mode to disable for text generation\n",
    "hybrid_model.set_fhe_mode(\"disable\")\n",
    "\n",
    "_ = generate_and_print(\n",
    "    prompt=\"Programming is\", model=hybrid_model.model, tokenizer=tokenizer, seed=SEED\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
