chore: add llm hybrid inference use case #1061

Open · wants to merge 4 commits into main
302 changes: 302 additions & 0 deletions use_case_examples/llm/GPT2HybridInference.ipynb
@@ -0,0 +1,302 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "dfccd8e6",
"metadata": {},
"source": [
"# LLM Encrypted Token Generation\n",
"\n",
"This notebook shows how to configure a GPT2 model to generate text based on an encrypted prompt. The GPT2 model shown\n",
"here runs some layers on the client-side machine and some on the server-side using FHE:\n",
"- on the client-side: non-linear layers, such as attention, normalization and activation functions\n",
"- on the server-side: all linear layers that have trained weights\n",
"\n",
"To generate one token, the model shown here requires around 11 seconds on a desktop GPU."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "eca73e44",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x73526e76a630>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Import necessary libraries\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, Conv1D\n",
"\n",
"from concrete.ml.torch.hybrid_model import HybridFHEModel\n",
"\n",
"# Set random seed for reproducibility\n",
"SEED = 0\n",
"torch.manual_seed(SEED)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c082411e",
"metadata": {},
"outputs": [],
"source": [
"def generate_and_print(prompt, model, tokenizer, seed=None, max_new_tokens=30):\n",
" \"\"\"\n",
" Generates text based on the provided prompt and prints both the prompt and the generated text.\n",
"\n",
" Args:\n",
" prompt (str): The input prompt to generate text from.\n",
" model: The pre-trained language model.\n",
" tokenizer: The tokenizer associated with the model.\n",
" seed (int, optional): Seed for random number generators to ensure reproducibility.\n",
" max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 30.\n",
" Returns:\n",
" str: The generated text (response only, without the prompt).\n",
" \"\"\"\n",
" # Set the environment variable for CuBLAS deterministic behavior\n",
" os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
"\n",
" # Set the random seed for reproducibility\n",
" if seed is not None:\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" torch.manual_seed(seed)\n",
" if torch.cuda.is_available():\n",
" torch.cuda.manual_seed_all(seed)\n",
"\n",
" # Encode the input prompt\n",
" inputs = tokenizer.encode_plus(prompt, return_tensors=\"pt\")\n",
"\n",
" # Generate text\n",
" with torch.no_grad():\n",
" output = model.generate(\n",
" input_ids=inputs[\"input_ids\"],\n",
" attention_mask=inputs[\"attention_mask\"],\n",
" max_new_tokens=max_new_tokens,\n",
" top_p=0.9,\n",
" temperature=0.6,\n",
" do_sample=True,\n",
" pad_token_id=tokenizer.eos_token_id,\n",
" )\n",
"\n",
" # Get only the newly generated tokens\n",
" input_length = inputs[\"input_ids\"].shape[1]\n",
" generated_ids = output[0, input_length:]\n",
" generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()\n",
"\n",
" # Print the prompt and generated text\n",
" print(f\"Prompt: {prompt}\")\n",
" print(f\"Response: {generated_text}\\n\")\n",
"\n",
" return generated_text"
]
},
{
"cell_type": "markdown",
"id": "794d9a40",
"metadata": {},
"source": [
"## Load the model for inference"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8b965a1a",
"metadata": {},
"outputs": [],
"source": [
"# Load pre-trained GPT-2 model and tokenizer\n",
"model_name = \"gpt2\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"\n",
"# Ensure tokenizer has a pad token\n",
"if tokenizer.pad_token is None:\n",
" tokenizer.pad_token = tokenizer.eos_token\n",
"model.config.pad_token_id = model.config.eos_token_id\n",
"\n",
"# Freeze model weights\n",
"for param in model.parameters():\n",
" param.requires_grad = False"
]
},
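{
"cell_type": "markdown",
"id": "f0a1b2c3",
"metadata": {},
"source": [
"As an optional sanity check (not needed for the rest of the notebook), the parameter counts\n",
"can be inspected to confirm that all weights are frozen:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9b8c7d6",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: GPT-2 (small) has ~124M parameters, all frozen for inference\n",
"n_params = sum(p.numel() for p in model.parameters())\n",
"n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total parameters: {n_params:,}, trainable: {n_trainable}\")"
]
},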
{
"cell_type": "markdown",
"id": "8d798bd0",
"metadata": {},
"source": [
"## Generate some tokens on a clear-text prompt"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2337a6b4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prompt: Programming is\n",
"Response: a skill you need to learn to master.\n",
"\n",
"Learn to code\n",
"\n",
"There are a lot of different ways to learn programming.\n",
"\n",
"The\n",
"\n"
]
}
],
"source": [
"_ = generate_and_print(prompt=\"Programming is\", model=model, tokenizer=tokenizer, seed=SEED)"
]
},
{
"cell_type": "markdown",
"id": "90ed3cbf",
"metadata": {},
"source": [
"## Configure the layers that execute remotely\n",
"\n",
"All linear layers in the GPT2 model can be detected by checking their Python type. Some layers\n",
"are expressed as 1d-Convolutional layers, but are executed as matrix-multiplications. The Concrete ML\n",
"Extensions backend provides functionality to execute such layers as encrypted-clear matrix multiplications."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a138d226",
"metadata": {},
"outputs": [],
"source": [
"remote_names = []\n",
"for name, module in model.named_modules():\n",
" if isinstance(module, (torch.nn.Linear, Conv1D)):\n",
" remote_names.append(name)"
]
},
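{
"cell_type": "markdown",
"id": "d5e4f3a2",
"metadata": {},
"source": [
"As a quick, optional check, the detected module names can be printed to confirm that the\n",
"attention and MLP projection layers were found:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c2d3e4",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: count the detected linear/Conv1D layers and show a few names\n",
"print(f\"{len(remote_names)} remote layers detected\")\n",
"print(remote_names[:5])"
]
},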
{
"cell_type": "code",
"execution_count": 14,
"id": "ae2094a4",
"metadata": {},
"outputs": [],
"source": [
"# Create the HybridFHEModel with the specified remote modules\n",
"hybrid_model = HybridFHEModel(model, module_names=remote_names)"
]
},
{
"cell_type": "markdown",
"id": "b81ea79f",
"metadata": {},
"source": [
"## Compile the model\n",
"\n",
"This step determines the quantization parameters needed for inference. Dynamic quantization \n",
"is used and ensures quantization parameters are customized for each token of the token sequence."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "20dfe2d8",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6b36294657bb457687a13d7b043298e5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Compiling FHE layers: 0%| | 0/49 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"BLOCK_SIZE = 32\n",
"# Prepare input data for calibration\n",
"input_tensor = torch.randint(0, tokenizer.vocab_size, (256, BLOCK_SIZE), dtype=torch.long)\n",
"\n",
"# Calibrate and compile the model\n",
"hybrid_model.compile_model(input_tensor, n_bits=8, use_dynamic_quantization=True)"
]
},
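{
"cell_type": "markdown",
"id": "e6f7a8b9",
"metadata": {},
"source": [
"The sketch below is purely illustrative and not part of the Concrete ML API: it mimics the idea\n",
"behind per-token dynamic quantization, where each token's activations get their own scale derived\n",
"from their dynamic range. Concrete ML performs the actual quantization internally during\n",
"`compile_model`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c9d8e7f6",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only: per-token symmetric quantization to 8 bits.\n",
"# Concrete ML handles this internally; this cell just demonstrates the idea.\n",
"x = torch.randn(4, 768)  # hypothetical activations: 4 tokens, hidden size 768\n",
"q_max = 2 ** (8 - 1) - 1  # 127, the largest signed 8-bit value\n",
"# One scale per token (row), derived from that token's maximum magnitude\n",
"scales = x.abs().max(dim=1, keepdim=True).values / q_max\n",
"x_dq = torch.round(x / scales).clamp(-q_max - 1, q_max) * scales\n",
"print(f\"Max per-token quantization error: {(x - x_dq).abs().max():.5f}\")"
]
},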
{
"cell_type": "markdown",
"id": "65d448c8",
"metadata": {},
"source": [
"## Generate a few encrypted tokens based on an encrypted prompt"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3e91ad0b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prompt: Programming is\n",
"Response: one of those things that, the first time around, the development team, the full-time team, would, in theory, acknowledge a\n",
"\n"
]
}
],
"source": [
"# Set FHE mode to disable for text generation\n",
"hybrid_model.set_fhe_mode(\"execute\")\n",
"\n",
"_ = generate_and_print(\n",
" prompt=\"Programming is\", model=hybrid_model.model, tokenizer=tokenizer, seed=SEED\n",
")"
]
},
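{
"cell_type": "markdown",
"id": "a7b6c5d4",
"metadata": {},
"source": [
"For comparison, the hybrid model can be switched back to clear (non-FHE) execution. The\n",
"\"disable\" mode is assumed to run the original floating-point layers locally:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8e9d0c1",
"metadata": {},
"outputs": [],
"source": [
"# Assumed mode name: \"disable\" should run the original (non-FHE) layers locally\n",
"hybrid_model.set_fhe_mode(\"disable\")\n",
"\n",
"_ = generate_and_print(\n",
"    prompt=\"Programming is\", model=hybrid_model.model, tokenizer=tokenizer, seed=SEED\n",
")"
]
},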
{
"cell_type": "markdown",
"id": "3ab85824",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"This notebook shows that with a few lines of code a GPT2 model can be converted \n",
"to execute on encrypted data. "
]
}
],
"metadata": {
"execution": {
"timeout": 10800
}
},
"nbformat": 4,
"nbformat_minor": 5
}