
Commit 731d8c6

Merge pull request #82 from huggingface/embd_combination

Refactored Image and Text combination to use placeholder tokens instead of simple concatenation

2 parents 098db57 + 072441b · commit 731d8c6

12 files changed (+187 -108 lines)

.gitignore

Lines changed: 3 additions & 1 deletion

@@ -23,4 +23,6 @@ uv.lock
 
 checkpoints/
 notebooks/
-*.slurm
+*.slurm
+
+benchmark_results.json

README.md

Lines changed: 5 additions & 0 deletions

@@ -13,6 +13,11 @@
 
 ---
 
+> [!NOTE]
+> We pushed some breaking changes to the repository on June 4. To enable smarter packing, we refactored the way image and text embeddings are combined. To keep everything as smooth as possible, we trained a new nanoVLM-450M with the new pipeline, while leaving the old nanoVLM-222M compatible with the old pipeline. If you clone this repository now or pull the updates to your local machine, the default will be the new 450M model. If you would prefer a simpler codebase that is easier to understand, use the v0.1 release, which works out of the box with the old 222M model.
+
+---
+
 nanoVLM is the simplest repository for training/finetuning a small sized Vision-Language Model with a lightweight implementation in pure PyTorch. The code itself is very readable and approachable: the model consists of a Vision Backbone (`models/vision_transformer.py` ~150 lines), Language Decoder (`models/language_model.py` ~250 lines), Modality Projection (`models/modality_projection.py` ~50 lines) and the VLM itself (`models/vision_language_model.py` ~100 lines), plus a simple training loop (`train.py` ~200 lines).
 
 Similar to Andrej Karpathy's nanoGPT, we wanted to equip the community with a very simple implementation and training script for Vision Language Models. We do not claim this to be a new SOTA model, rather an educational effort that packs quite a bit of punch if you have the right hardware! You should be able to tweak and play around with the code in no time.
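
As a rough mental model of the refactor: instead of concatenating projected image embeddings in front of the text embeddings, the prompt now contains one <|image|> placeholder token per projected image embedding, and the embeddings at those placeholder positions are overwritten with the image features inside the VLM. Below is a minimal sketch of that idea in plain PyTorch; the function name and tensor names are illustrative assumptions, not the repository's actual API.

    import torch

    def merge_image_embeddings(token_embeds, input_ids, image_embeds, image_token_id):
        # token_embeds: (B, T, D) text embeddings, placeholder positions included
        # image_embeds: (B, mp_image_token_length, D) projected vision features
        # image_token_id: id of the <|image|> special token
        merged = token_embeds.clone()
        for b in range(input_ids.size(0)):
            positions = (input_ids[b] == image_token_id).nonzero(as_tuple=True)[0]
            assert positions.numel() == image_embeds.size(1)  # one placeholder per image embedding
            merged[b, positions] = image_embeds[b]
        return merged

Because the placeholders occupy real token positions, attention masks, label masking, and packing can treat the combined sequence like ordinary text.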

benchmark-inference.py

Lines changed: 3 additions & 3 deletions

@@ -17,14 +17,14 @@ def generate_tokens(tokens, image):
     gen = model.generate(tokens, image, max_new_tokens=1000)
 
 if __name__ == "__main__":
-    model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M").to(device)
+    model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-450M").to(device)
     model.eval()
 
-    tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
+    tokenizer = get_tokenizer(model.cfg.lm_tokenizer, model.cfg.vlm_extra_tokens)
     image_processor = get_image_processor(model.cfg.vit_img_size)
 
     text = "What is this?"
-    template = f"Question: {text} Answer:"
+    template = f"{tokenizer.image_token * model.cfg.mp_image_token_length}Question: {text} Answer:"
     encoded_batch = tokenizer.batch_encode_plus([template], return_tensors="pt")
     tokens = encoded_batch['input_ids'].to(device)
 
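
For orientation, the updated template simply prepends mp_image_token_length copies of the image placeholder to the textual prompt. A quick illustrative check of what it expands to, using the new config defaults:

    image_token = "<|image|>"
    mp_image_token_length = 64  # VLMConfig default in models/config.py

    template = f"{image_token * mp_image_token_length}Question: What is this? Answer:"
    print(template[:40])                # '<|image|><|image|><|image|><|image|><|im'
    print(template.count(image_token))  # 64, one placeholder per projected image embedding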

benchmark_suite.py

Lines changed: 1 addition & 1 deletion

@@ -45,7 +45,7 @@ def benchmark_vlm(
         vlm_load_backbone_weights=True
     )
     model = VisionLanguageModel(cfg, load_backbone=True).to(device).eval()
-    tokenizer = get_tokenizer(cfg.lm_tokenizer)
+    tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens)
     vit_img_size = int(cfg.vit_model_type[-3:]) # Kinda hacky, works for siglip models
     image_processor = get_image_processor(vit_img_size)
 
5151

data/collators.py

Lines changed: 43 additions & 28 deletions

@@ -1,9 +1,12 @@
 import torch
 
 class VQACollator(object): # Visual Question Answering Collator
-    def __init__(self, tokenizer, max_length):
+    def __init__(self, tokenizer, max_length, mp_image_token_length):
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.mp_image_token_length = mp_image_token_length
+
+        self.image_token_str = tokenizer.image_token
 
     def __call__(self, batch):
         images = [item["image"] for item in batch]
@@ -13,10 +16,11 @@ def __call__(self, batch):
         # Stack images
         images = torch.stack(images)
 
-        # Create inputs by concatenating the question and answer
+        # Create inputs by concatenating special image tokens, question, and answer
        input_sequences = []
         for i in range(len(texts)):
-            input_sequences.append(f"{texts[i]}{answers[i]}")
+            # Construct the image token segment string
+            input_sequences.append(f"{self.image_token_str * self.mp_image_token_length}{texts[i]}{answers[i]}")
 
         encoded_full_sequences = self.tokenizer.batch_encode_plus(
             input_sequences,
@@ -31,39 +35,41 @@ def __call__(self, batch):
         input_ids = encoded_full_sequences["input_ids"]
         attention_mask = encoded_full_sequences["attention_mask"]
         labels = input_ids.clone()
-        labels[:, :-1] = input_ids[:, 1:].clone()
-        labels[:, -1] = -100 #self.tokenizer.pad_token_id
-
-        # The tokenizer has different behavior for padding and truncation:
-        # 1. If the full text (answer + question) is shorter than the max length, it gets padded on the left
-        # 2. If the full text is longer than the max length, it gets truncated on the right
-        # Therefore, I need to handle multiple cases, this is the different scenarios:
-        # If the full text is longer than the max length, we need to set the labels to -100 for the whole sample (we want to ignore the whole sample)
-        # If the full text is shorter than the max length, we need to set the labels to -100 only for the question part, and create causal language modeling labels for the answer part, taking into account the padding
+        labels[:, :-1] = input_ids[:, 1:].clone() # Shift labels for causal LM
+        labels[:, -1] = -100 # Last token has no target
 
-        # Determine if sequences were truncated
+        # Determine original lengths before padding/truncation to handle truncation cases
         original_lengths = [len(self.tokenizer.encode(seq)) for seq in input_sequences]
-
+
         for i in range(len(batch)):
-            # Get the length of the question for this sample
-            question_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))
-
             # Case 1: If sequence was truncated (original is longer than max_length)
             if original_lengths[i] > self.max_length:
-                # Set all labels to -100 to ignore this sample entirely
-                labels[i, :] = -100
-                #print(f"Sample {i} was truncated. Setting all labels to -100.")
+                labels[i, :] = -100 # Ignore this sample entirely
+                # print(f"Sample {i} truncated: original length {original_lengths[i]} exceeds max_length {self.max_length}. Ignoring sample.")
                 continue
 
             # Case 2: Sequence fits within max_length
-            # Use attention mask to find first non-padding token
-            # The first 1 in the attention mask marks the first non-padding token
+            # Determine the length of the question part for this sample
+            question_part_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))
+
+            # Find the position of the first actual token (non-padding)
+            # attention_mask might be all zeros if the sequence is fully truncated (handled above) or empty.
+            # Ensure there's at least one non-padding token to avoid errors with .nonzero().
+            if attention_mask[i].sum() == 0: # Should not happen if not truncated and not empty.
+                labels[i, :] = -100 # Defensive: if no actual tokens, ignore sample
+                continue
+
             first_token_pos = attention_mask[i].nonzero(as_tuple=True)[0][0].item()
 
-            # Set labels for padding and question part to -100 (don't predict these), substracting 1 to account for the left shift
-            question_end = first_token_pos + question_length - 1
-            labels[i, :question_end] = -100
-            # labels[i, original_lengths[i]-1:] = -100 # If you are using right padding
+            # The total length of the "prompt" part (special image tokens + question)
+            total_prompt_length = self.mp_image_token_length + question_part_length
+
+            # Mask labels for padding tokens (before first_token_pos) and the entire prompt part.
+            # The prompt part starts at first_token_pos and has length total_prompt_length.
+            # So, tokens from index 0 up to (first_token_pos + total_prompt_length - 1) should be masked.
+            # The slicing labels[i, :N] masks indices 0 to N-1.
+            mask_until_idx = first_token_pos + total_prompt_length - 1
+            labels[i, :mask_until_idx] = -100
 
         return {
             "image": images,
@@ -73,8 +79,11 @@ def __call__(self, batch):
         }
 
 class MMStarCollator(object): # https://huggingface.co/datasets/Lin-Chen/MMStar
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, mp_image_token_length):
         self.tokenizer = tokenizer
+        self.mp_image_token_length = mp_image_token_length
+
+        self.image_token_str = tokenizer.image_token
 
     def __call__(self, batch):
         images = [item["image"] for item in batch]
@@ -83,9 +92,15 @@ def __call__(self, batch):
 
         # Stack images
         images = torch.stack(images)
+
+        # Create input sequences with image placeholders
+        question_sequences = []
+        for question_text in questions:
+            question_sequences.append(f"{self.image_token_str * self.mp_image_token_length}{question_text}")
+
 
         encoded_question_sequences = self.tokenizer.batch_encode_plus(
-            questions,
+            question_sequences,
             padding=True,
             padding_side="left",
             return_tensors="pt"
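
The masking arithmetic in VQACollator is easiest to see with small numbers. A toy walk-through follows; the numbers are made up for illustration (the real config uses 64 image placeholders):

    # Sequence layout after left padding (toy example):
    #   index: 0    1    2    3    4    5    6   7   8   9   10  11
    #   token: pad  pad  img  img  img  img  q1  q2  q3  a1  a2  a3
    first_token_pos = 2          # first non-padding position, from the attention mask
    mp_image_token_length = 4    # image placeholders per sample (64 in the real config)
    question_part_length = 3     # tokens in the question text

    total_prompt_length = mp_image_token_length + question_part_length   # 7
    mask_until_idx = first_token_pos + total_prompt_length - 1           # 8

    # labels[i, :8] = -100 masks label positions 0..7. Because labels are the inputs
    # shifted left by one, the first supervised label sits at position 8 and targets
    # the token at input index 9 -- the first answer token, which is exactly where
    # supervision should begin.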

data/processors.py

Lines changed: 5 additions & 2 deletions

@@ -3,9 +3,12 @@
 
 TOKENIZERS_CACHE = {}
 
-def get_tokenizer(name):
+def get_tokenizer(name, extra_special_tokens=None):
     if name not in TOKENIZERS_CACHE:
-        tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
+        tokenizer_init_kwargs = {"use_fast": True}
+        if extra_special_tokens is not None:
+            tokenizer_init_kwargs["extra_special_tokens"] = extra_special_tokens
+        tokenizer = AutoTokenizer.from_pretrained(name, **tokenizer_init_kwargs,)
         tokenizer.pad_token = tokenizer.eos_token
         TOKENIZERS_CACHE[name] = tokenizer
     return TOKENIZERS_CACHE[name]
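
With the new signature, the config's vlm_extra_tokens dict is forwarded to AutoTokenizer.from_pretrained as extra_special_tokens, so the tokenizer exposes the placeholder as tokenizer.image_token (as used in the scripts above). A short usage sketch, assuming it is run from the repository root so that the models and data packages are importable:

    from models.config import VLMConfig
    from data.processors import get_tokenizer

    cfg = VLMConfig()
    tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens)

    print(tokenizer.image_token)   # '<|image|>'
    # Per the config comment, the vocab grows by extra_token_amount (here 1),
    # so len(tokenizer) should equal cfg.lm_vocab_size.
    image_token_id = tokenizer.convert_tokens_to_ids(tokenizer.image_token)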

generate.py

Lines changed: 3 additions & 3 deletions

@@ -18,7 +18,7 @@ def parse_args():
         help="Path to a local checkpoint (directory or safetensors/pth). If omitted, we pull from HF."
     )
     parser.add_argument(
-        "--hf_model", type=str, default="lusxvr/nanoVLM-222M",
+        "--hf_model", type=str, default="lusxvr/nanoVLM-450M",
         help="HuggingFace repo ID to download from incase --checkpoint isnt set."
     )
     parser.add_argument("--image", type=str, default="assets/image.png",
@@ -48,10 +48,10 @@ def main():
     model = VisionLanguageModel.from_pretrained(source).to(device)
     model.eval()
 
-    tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
+    tokenizer = get_tokenizer(model.cfg.lm_tokenizer, model.cfg.vlm_extra_tokens)
     image_processor = get_image_processor(model.cfg.vit_img_size)
 
-    template = f"Question: {args.prompt} Answer:"
+    template = f"{tokenizer.image_token * model.cfg.mp_image_token_length}Question: {args.prompt} Answer:"
     encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
     tokens = encoded["input_ids"].to(device)
 

measure_vram.py

Lines changed: 2 additions & 2 deletions

@@ -45,7 +45,7 @@ def measure_vram(args, vlm_cfg, train_cfg_defaults):
 
     # --- Dataset Preparation ---
     image_processor = get_image_processor(vlm_cfg.vit_img_size)
-    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer)
+    tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens)
 
     dataset_path = train_cfg_defaults.train_dataset_path
     # train_cfg_defaults.train_dataset_name is a list, use the first if not specified
@@ -82,7 +82,7 @@ def measure_vram(args, vlm_cfg, train_cfg_defaults):
         return
 
     processed_base_dataset = VQADataset(base_ds_for_vram_test, tokenizer, image_processor)
-    vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length)
+    vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length, vlm_cfg.mp_image_token_length)
 
     print("\n--- VRAM Measurement ---")
     results = {}

models/config.py

Lines changed: 16 additions & 14 deletions

@@ -1,56 +1,58 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 
 @dataclass
 class VLMConfig:
     vit_hidden_dim: int = 768
     vit_inter_dim: int = 4 * vit_hidden_dim
     vit_patch_size: int = 16
-    vit_img_size: int = 224
+    vit_img_size: int = 256
     vit_n_heads: int = 12
     vit_dropout: float = 0.0
     vit_n_blocks: int = 12
     vit_ln_eps: float = 1e-6
     vit_cls_flag: bool = False
-    vit_model_type: str = 'google/siglip-base-patch16-224'
+    vit_model_type: str = 'google/siglip2-base-patch16-256'
 
     lm_hidden_dim: int = 576
     lm_inter_dim: int = 1536
     lm_rms_eps: float = 1e-5
     lm_re_base: int = 100000
     lm_max_position_embeddings: int = 8192
-    lm_vocab_size: int = 49152
+    lm_base_vocab_size: int = 49152
+    extra_token_amount: int = 1 # Number of extra tokens for the VLM (image start, image end, image token)
+    lm_vocab_size: int = lm_base_vocab_size + extra_token_amount # Not a great way to do this, but it works for now (vlm_extra_tokens cannot be a dict, since this is mutable, and a Field has no len() function)
     lm_n_heads: int = 9
     lm_n_kv_heads: int = 3
     lm_dropout: float = 0.0
     lm_n_blocks: int = 30
     lm_attn_scaling: float = 1.0
-    IMAGE_TOKEN_LENGTH: int = 49
-    TOTAL_SEQUENCE_LENGTH: int = 128
-    lm_max_length: int = TOTAL_SEQUENCE_LENGTH - IMAGE_TOKEN_LENGTH # Maximum length for the language model, derived from TOTAL_SEQUENCE_LENGTH and IMAGE_TOKEN_LENGTH
+    lm_max_length: int = 512
     lm_use_tokens: bool = False # Decide if the LM expects tokens or embeddings as input (if using as a backbone for the VLM, set to False)
     lm_tie_weights: bool = True # Decide if you want to tie the LM Head weight to the token embedding weights
-    lm_model_type: str = 'HuggingFaceTB/SmolLM2-135M'
-    lm_tokenizer: str = 'HuggingFaceTB/cosmo2-tokenizer'
+    lm_model_type: str = 'HuggingFaceTB/SmolLM2-360M-Instruct'
+    lm_tokenizer: str = 'HuggingFaceTB/SmolLM2-360M-Instruct'
     lm_eos_token_id: int = 0
 
     mp_pixel_shuffle_factor: int = 2
+    mp_image_token_length: int = 64
 
+    vlm_extra_tokens: dict[str, str] = field(default_factory=lambda: {"image_token": "<|image|>"})#, "boi_token": "<|image_start|>", "eoi_token": "<|image_end|>"})
     vlm_load_backbone_weights: bool = True
     vlm_checkpoint_path: str = 'checkpoints'
     hf_repo_name: str = 'nanoVLM'
 
 
 @dataclass
 class TrainConfig:
-    lr_mp: float = 2e-3
-    lr_backbones: float = 1e-4
+    lr_mp: float = 0.003
+    lr_backbones: float = 5e-5
     data_cutoff_idx: int = None
     val_ratio: float = 0.025
-    batch_size: int = 256
-    gradient_accumulation_steps: int = 1
+    batch_size: int = 32
+    gradient_accumulation_steps: int = 4
     mmstar_batch_size: int = 32
-    max_grad_norm: float = None
+    max_grad_norm: float = 1.0
     eval_in_epochs: bool = True
     eval_interval: int = 250
     epochs: int = 5
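
The new defaults hang together: with 256 px inputs, 16 px patches, and a pixel-shuffle factor of 2, the modality projector should emit exactly mp_image_token_length embeddings per image. A quick arithmetic check (the token-count formula is our reading of the config, not something stated in the diff):

    vit_img_size = 256
    vit_patch_size = 16
    mp_pixel_shuffle_factor = 2

    patches_per_side = vit_img_size // vit_patch_size                   # 16
    image_tokens = (patches_per_side // mp_pixel_shuffle_factor) ** 2   # 8 ** 2 = 64
    assert image_tokens == 64  # matches mp_image_token_length

    # Training defaults: the smaller per-step batch is partly offset by accumulation.
    effective_batch_size = 32 * 4   # batch_size * gradient_accumulation_steps = 128

The same formula reproduces the old setting: (224 / 16 / 2) ** 2 = 49, the previous IMAGE_TOKEN_LENGTH.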
