
Refactored Image and Text combination to use placeholder tokens instead of simple concatenation #82

Merged: 28 commits, Jun 4, 2025
Commits (28)
6e53742
added special tokens
lusxvr May 27, 2025
0846093
made changes backwards compatible
lusxvr May 27, 2025
70df1dc
adapted collator to handle image replacement tokens
lusxvr May 27, 2025
7a3ec7e
adapted VLM to handle image replacement tokens (MMStar and therefore …
lusxvr May 27, 2025
04e0fd4
adapted evaluation
lusxvr May 28, 2025
80d36a6
Merge branch 'main' into embd_combination
lusxvr May 28, 2025
4dd05c8
comparison run to main
lusxvr May 28, 2025
00ea46e
changed token/sec calculation
lusxvr May 28, 2025
41f6996
fixed forward loop
lusxvr May 28, 2025
ee6c6fa
simplified logic and improved generate
lusxvr May 28, 2025
c1d21e4
ablation runs
lusxvr May 29, 2025
807da1e
fixed grad norm log when using grad accum
lusxvr May 30, 2025
1789ae3
test run
lusxvr May 30, 2025
52f99f0
back to old config
lusxvr May 30, 2025
3813a47
cleaned logging
lusxvr May 30, 2025
5079e4e
tried to fix generate (still not working)
lusxvr Jun 2, 2025
a507702
trained 450M model with new embeddings
lusxvr Jun 2, 2025
527856f
changed tokenizer back to cosmo
lusxvr Jun 3, 2025
6360627
fixed typo
lusxvr Jun 3, 2025
b8b19a5
changed default model in generate
lusxvr Jun 3, 2025
bd6e509
cleaned config
lusxvr Jun 3, 2025
6e92046
more comprehensive run dating and max grad norm
lusxvr Jun 3, 2025
9f9b028
cleaned and incorporated suggestions
lusxvr Jun 3, 2025
614d913
post-processed generate
lusxvr Jun 3, 2025
524ba29
cleaned branch for merge and checked compatibility
lusxvr Jun 3, 2025
29dd1c5
cleaned naming
lusxvr Jun 3, 2025
29706d9
fixed lr scheduler
lusxvr Jun 3, 2025
072441b
updated config and README
lusxvr Jun 4, 2025
4 changes: 3 additions & 1 deletion .gitignore
@@ -23,4 +23,6 @@ uv.lock

checkpoints/
notebooks/
*.slurm
*.slurm

benchmark_results.json
5 changes: 5 additions & 0 deletions README.md
@@ -13,6 +13,11 @@

---

> [!NOTE]
> We pushed some breaking changes to the repository on June 4. To enable smarter packing, we refactored the way image and text embeddings are combined. To keep the transition as smooth as possible, we trained a new nanoVLM-450M with the new pipeline, while leaving the old nanoVLM-222M compatible with the old pipeline. If you clone this repository now or pull the update to your local machine, the default will be the new 450M model. If you prefer a simpler codebase that is easier to understand, you can use the v0.1 release, which works out of the box with the old 222M model.

---

nanoVLM is the simplest repository for training/finetuning a small-sized Vision-Language Model with a lightweight implementation in pure PyTorch. The code itself is very readable and approachable: the model consists of a Vision Backbone (`models/vision_transformer.py` ~150 lines), a Language Decoder (`models/language_model.py` ~250 lines), a Modality Projection (`models/modality_projection.py` ~50 lines) and the VLM itself (`models/vision_language_model.py` ~100 lines), plus a simple training loop (`train.py` ~200 lines).

Similar to Andrej Karpathy's nanoGPT, we wanted to equip the community with a very simple implementation and training script for Vision Language Models. We do not claim this to be a new SOTA model, rather an educational effort that packs quite a bit of punch if you have the right hardware! You should be able to tweak and play around with the code in no time.
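For readers following the refactor described in the README note above, here is a minimal usage sketch of the new placeholder-token pipeline. It simply mirrors the updated `generate.py` and `benchmark-inference.py` shown below and assumes the released `lusxvr/nanoVLM-450M` checkpoint; it is not part of this diff.

```python
# Minimal sketch of the new placeholder-token pipeline (not part of this diff).
# Assumes the lusxvr/nanoVLM-450M checkpoint and the helpers touched in this PR.
import torch
from models.vision_language_model import VisionLanguageModel
from data.processors import get_tokenizer, get_image_processor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-450M").to(device).eval()

tokenizer = get_tokenizer(model.cfg.lm_tokenizer, model.cfg.vlm_extra_tokens)
image_processor = get_image_processor(model.cfg.vit_img_size)

# Instead of concatenating image embeddings in front of the text embeddings,
# the prompt now carries one <|image|> placeholder per projected image token;
# the model fills those positions with the image embeddings internally.
prompt = f"{tokenizer.image_token * model.cfg.mp_image_token_length}Question: What is this? Answer:"
tokens = tokenizer.batch_encode_plus([prompt], return_tensors="pt")["input_ids"].to(device)
```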
6 changes: 3 additions & 3 deletions benchmark-inference.py
@@ -17,14 +17,14 @@ def generate_tokens(tokens, image):
gen = model.generate(tokens, image, max_new_tokens=1000)

if __name__ == "__main__":
model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M").to(device)
model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-450M").to(device)
model.eval()

tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
tokenizer = get_tokenizer(model.cfg.lm_tokenizer, model.cfg.vlm_extra_tokens)
image_processor = get_image_processor(model.cfg.vit_img_size)

text = "What is this?"
template = f"Question: {text} Answer:"
template = f"{tokenizer.image_token * model.cfg.mp_image_token_length}Question: {text} Answer:"
encoded_batch = tokenizer.batch_encode_plus([template], return_tensors="pt")
tokens = encoded_batch['input_ids'].to(device)

2 changes: 1 addition & 1 deletion benchmark_suite.py
@@ -45,7 +45,7 @@ def benchmark_vlm(
vlm_load_backbone_weights=True
)
model = VisionLanguageModel(cfg, load_backbone=True).to(device).eval()
tokenizer = get_tokenizer(cfg.lm_tokenizer)
tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens)
vit_img_size = int(cfg.vit_model_type[-3:]) # Kinda hacky, works for siglip models
image_processor = get_image_processor(vit_img_size)

71 changes: 43 additions & 28 deletions data/collators.py
@@ -1,9 +1,12 @@
import torch

class VQACollator(object): # Visual Question Answering Collator
def __init__(self, tokenizer, max_length):
def __init__(self, tokenizer, max_length, mp_image_token_length):
self.tokenizer = tokenizer
self.max_length = max_length
self.mp_image_token_length = mp_image_token_length

self.image_token_str = tokenizer.image_token

def __call__(self, batch):
images = [item["image"] for item in batch]
@@ -13,10 +16,11 @@ def __call__(self, batch):
# Stack images
images = torch.stack(images)

# Create inputs by concatenating the question and answer
# Create inputs by concatenating special image tokens, question, and answer
input_sequences = []
for i in range(len(texts)):
input_sequences.append(f"{texts[i]}{answers[i]}")
# Construct the image token segment string
input_sequences.append(f"{self.image_token_str * self.mp_image_token_length}{texts[i]}{answers[i]}")

encoded_full_sequences = self.tokenizer.batch_encode_plus(
input_sequences,
@@ -31,39 +35,41 @@ def __call__(self, batch):
input_ids = encoded_full_sequences["input_ids"]
attention_mask = encoded_full_sequences["attention_mask"]
labels = input_ids.clone()
labels[:, :-1] = input_ids[:, 1:].clone()
labels[:, -1] = -100 #self.tokenizer.pad_token_id

# The tokenizer has different behavior for padding and truncation:
# 1. If the full text (answer + question) is shorter than the max length, it gets padded on the left
# 2. If the full text is longer than the max length, it gets truncated on the right
# Therefore, I need to handle multiple cases, this is the different scenarios:
# If the full text is longer than the max length, we need to set the labels to -100 for the whole sample (we want to ignore the whole sample)
# If the full text is shorter than the max length, we need to set the labels to -100 only for the question part, and create causal language modeling labels for the answer part, taking into account the padding
labels[:, :-1] = input_ids[:, 1:].clone() # Shift labels for causal LM
labels[:, -1] = -100 # Last token has no target

# Determine if sequences were truncated
# Determine original lengths before padding/truncation to handle truncation cases
original_lengths = [len(self.tokenizer.encode(seq)) for seq in input_sequences]

for i in range(len(batch)):
# Get the length of the question for this sample
question_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))

# Case 1: If sequence was truncated (original is longer than max_length)
if original_lengths[i] > self.max_length:
# Set all labels to -100 to ignore this sample entirely
labels[i, :] = -100
#print(f"Sample {i} was truncated. Setting all labels to -100.")
labels[i, :] = -100 # Ignore this sample entirely
# print(f"Sample {i} truncated: original length {original_lengths[i]} exceeds max_length {self.max_length}. Ignoring sample.")
continue

# Case 2: Sequence fits within max_length
# Use attention mask to find first non-padding token
# The first 1 in the attention mask marks the first non-padding token
# Determine the length of the question part for this sample
question_part_length = len(self.tokenizer.encode(texts[i], add_special_tokens=False))

# Find the position of the first actual token (non-padding)
# attention_mask might be all zeros if the sequence is fully truncated (handled above) or empty.
# Ensure there's at least one non-padding token to avoid errors with .nonzero().
if attention_mask[i].sum() == 0: # Should not happen if not truncated and not empty.
labels[i, :] = -100 # Defensive: if no actual tokens, ignore sample
continue

first_token_pos = attention_mask[i].nonzero(as_tuple=True)[0][0].item()

# Set labels for padding and question part to -100 (don't predict these), substracting 1 to account for the left shift
question_end = first_token_pos + question_length - 1
labels[i, :question_end] = -100
# labels[i, original_lengths[i]-1:] = -100 # If you are using right padding
# The total length of the "prompt" part (special image tokens + question)
total_prompt_length = self.mp_image_token_length + question_part_length

# Mask labels for padding tokens (before first_token_pos) and the entire prompt part.
# The prompt part starts at first_token_pos and has length total_prompt_length.
# So, tokens from index 0 up to (first_token_pos + total_prompt_length - 1) should be masked.
# The slicing labels[i, :N] masks indices 0 to N-1.
mask_until_idx = first_token_pos + total_prompt_length - 1
labels[i, :mask_until_idx] = -100

return {
"image": images,
@@ -73,8 +79,11 @@ def __call__(self, batch):
}

class MMStarCollator(object): # https://huggingface.co/datasets/Lin-Chen/MMStar
def __init__(self, tokenizer):
def __init__(self, tokenizer, mp_image_token_length):
self.tokenizer = tokenizer
self.mp_image_token_length = mp_image_token_length

self.image_token_str = tokenizer.image_token

def __call__(self, batch):
images = [item["image"] for item in batch]
@@ -83,9 +92,15 @@ def __call__(self, batch):

# Stack images
images = torch.stack(images)

# Create input sequences with image placeholders
question_sequences = []
for question_text in questions:
question_sequences.append(f"{self.image_token_str * self.mp_image_token_length}{question_text}")


encoded_question_sequences = self.tokenizer.batch_encode_plus(
questions,
question_sequences,
padding=True,
padding_side="left",
return_tensors="pt"
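A small worked example of the label-masking arithmetic in the updated VQACollator above, using made-up lengths (the numbers below are hypothetical, chosen only to illustrate the indexing):

```python
# Hypothetical lengths, chosen only to illustrate VQACollator's masking indices.
pad_len = 2                      # left padding added by the tokenizer
mp_image_token_length = 4        # image placeholder tokens in the prompt
question_len = 5                 # tokens in the question text

first_token_pos = pad_len                                   # first non-pad index = 2
total_prompt_length = mp_image_token_length + question_len  # 4 + 5 = 9

# Labels are input_ids shifted left by one, so position i predicts token i + 1.
# The first answer token sits at index first_token_pos + total_prompt_length = 11,
# and is therefore predicted at index 10.
mask_until_idx = first_token_pos + total_prompt_length - 1  # 2 + 9 - 1 = 10
# labels[i, :mask_until_idx] = -100 masks indices 0..9 (padding + image tokens +
# question); the loss starts at index 10, which targets the first answer token.
```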
7 changes: 5 additions & 2 deletions data/processors.py
@@ -3,9 +3,12 @@

TOKENIZERS_CACHE = {}

def get_tokenizer(name):
def get_tokenizer(name, extra_special_tokens=None):
if name not in TOKENIZERS_CACHE:
tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
tokenizer_init_kwargs = {"use_fast": True}
if extra_special_tokens is not None:
tokenizer_init_kwargs["extra_special_tokens"] = extra_special_tokens
tokenizer = AutoTokenizer.from_pretrained(name, **tokenizer_init_kwargs,)
tokenizer.pad_token = tokenizer.eos_token
TOKENIZERS_CACHE[name] = tokenizer
return TOKENIZERS_CACHE[name]
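As a quick sanity check of the updated `get_tokenizer` above, the sketch below shows how the extra special token becomes available. It assumes the default `VLMConfig` from this PR (shown further below) and relies on Hugging Face's `extra_special_tokens` handling, so treat it as an illustration rather than part of the diff.

```python
# Sketch of the extra-special-token wiring (assumes the default VLMConfig from
# this PR and Hugging Face's extra_special_tokens handling).
from data.processors import get_tokenizer
from models.config import VLMConfig

cfg = VLMConfig()
tokenizer = get_tokenizer(cfg.lm_tokenizer, cfg.vlm_extra_tokens)

# The dict {"image_token": "<|image|>"} is exposed as an attribute on the
# tokenizer, which is what the collators and generate.py rely on.
assert tokenizer.image_token == "<|image|>"
# The token also gets an id, which is why lm_vocab_size in the config is
# lm_base_vocab_size + extra_token_amount.
image_token_id = tokenizer.convert_tokens_to_ids(tokenizer.image_token)
```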
6 changes: 3 additions & 3 deletions generate.py
@@ -18,7 +18,7 @@ def parse_args():
help="Path to a local checkpoint (directory or safetensors/pth). If omitted, we pull from HF."
)
parser.add_argument(
"--hf_model", type=str, default="lusxvr/nanoVLM-222M",
"--hf_model", type=str, default="lusxvr/nanoVLM-450M",
help="HuggingFace repo ID to download from incase --checkpoint isnt set."
)
parser.add_argument("--image", type=str, default="assets/image.png",
@@ -48,10 +48,10 @@ def main():
model = VisionLanguageModel.from_pretrained(source).to(device)
model.eval()

tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
tokenizer = get_tokenizer(model.cfg.lm_tokenizer, model.cfg.vlm_extra_tokens)
image_processor = get_image_processor(model.cfg.vit_img_size)

template = f"Question: {args.prompt} Answer:"
template = f"{tokenizer.image_token * model.cfg.mp_image_token_length}Question: {args.prompt} Answer:"
encoded = tokenizer.batch_encode_plus([template], return_tensors="pt")
tokens = encoded["input_ids"].to(device)

4 changes: 2 additions & 2 deletions measure_vram.py
@@ -45,7 +45,7 @@ def measure_vram(args, vlm_cfg, train_cfg_defaults):

# --- Dataset Preparation ---
image_processor = get_image_processor(vlm_cfg.vit_img_size)
tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer)
tokenizer = get_tokenizer(vlm_cfg.lm_tokenizer, vlm_cfg.vlm_extra_tokens)

dataset_path = train_cfg_defaults.train_dataset_path
# train_cfg_defaults.train_dataset_name is a list, use the first if not specified
@@ -82,7 +82,7 @@ def measure_vram(args, vlm_cfg, train_cfg_defaults):
return

processed_base_dataset = VQADataset(base_ds_for_vram_test, tokenizer, image_processor)
vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length)
vqa_collator = VQACollator(tokenizer, vlm_cfg.lm_max_length, vlm_cfg.mp_image_token_length)

print("\n--- VRAM Measurement ---")
results = {}
30 changes: 16 additions & 14 deletions models/config.py
@@ -1,56 +1,58 @@
from dataclasses import dataclass
from dataclasses import dataclass, field


@dataclass
class VLMConfig:
vit_hidden_dim: int = 768
vit_inter_dim: int = 4 * vit_hidden_dim
vit_patch_size: int = 16
vit_img_size: int = 224
vit_img_size: int = 256
vit_n_heads: int = 12
vit_dropout: float = 0.0
vit_n_blocks: int = 12
vit_ln_eps: float = 1e-6
vit_cls_flag: bool = False
vit_model_type: str = 'google/siglip-base-patch16-224'
vit_model_type: str = 'google/siglip2-base-patch16-256'

lm_hidden_dim: int = 576
lm_inter_dim: int = 1536
lm_rms_eps: float = 1e-5
lm_re_base: int = 100000
lm_max_position_embeddings: int = 8192
lm_vocab_size: int = 49152
lm_base_vocab_size: int = 49152
extra_token_amount: int = 1 # Number of extra tokens for the VLM (image start, image end, image token)
lm_vocab_size: int = lm_base_vocab_size + extra_token_amount # Not a great way to do this, but it works for now (vlm_extra_tokens cannot be a dict, since this is mutable, and a Field has no len() function)
lm_n_heads: int = 9
lm_n_kv_heads: int = 3
lm_dropout: float = 0.0
lm_n_blocks: int = 30
lm_attn_scaling: float = 1.0
IMAGE_TOKEN_LENGTH: int = 49
TOTAL_SEQUENCE_LENGTH: int = 128
lm_max_length: int = TOTAL_SEQUENCE_LENGTH - IMAGE_TOKEN_LENGTH # Maximum length for the language model, derived from TOTAL_SEQUENCE_LENGTH and IMAGE_TOKEN_LENGTH
lm_max_length: int = 512
lm_use_tokens: bool = False # Decide if the LM expects tokens or embeddings as input (if using as a backbone for the VLM, set to False)
lm_tie_weights: bool = True # Decide if you want to tie the LM Head weight to the token embedding weights
lm_model_type: str = 'HuggingFaceTB/SmolLM2-135M'
lm_tokenizer: str = 'HuggingFaceTB/cosmo2-tokenizer'
lm_model_type: str = 'HuggingFaceTB/SmolLM2-360M-Instruct'
lm_tokenizer: str = 'HuggingFaceTB/SmolLM2-360M-Instruct'
lm_eos_token_id: int = 0

mp_pixel_shuffle_factor: int = 2
mp_image_token_length: int = 64

vlm_extra_tokens: dict[str, str] = field(default_factory=lambda: {"image_token": "<|image|>"})#, "boi_token": "<|image_start|>", "eoi_token": "<|image_end|>"})
vlm_load_backbone_weights: bool = True
vlm_checkpoint_path: str = 'checkpoints'
hf_repo_name: str = 'nanoVLM'


@dataclass
class TrainConfig:
lr_mp: float = 2e-3
lr_backbones: float = 1e-4
lr_mp: float = 0.003
lr_backbones: float = 5e-5
data_cutoff_idx: int = None
val_ratio: float = 0.025
batch_size: int = 256
gradient_accumulation_steps: int = 1
batch_size: int = 32
gradient_accumulation_steps: int = 4
mmstar_batch_size: int = 32
max_grad_norm: float = None
max_grad_norm: float = 1.0
eval_in_epochs: bool = True
eval_interval: int = 250
epochs: int = 5
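The new `mp_image_token_length: int = 64` follows from the vision settings above; here is a quick sanity check, assuming pixel shuffle with factor r merges r * r neighbouring patch tokens into one (the snippet is illustrative, not part of the diff):

```python
# Sanity check of the default geometry (assumes pixel shuffle with factor r
# reduces the patch-token count by r * r).
vit_img_size = 256
vit_patch_size = 16
mp_pixel_shuffle_factor = 2

num_patches = (vit_img_size // vit_patch_size) ** 2         # 16 * 16 = 256
image_tokens = num_patches // mp_pixel_shuffle_factor ** 2  # 256 / 4 = 64
assert image_tokens == 64  # matches mp_image_token_length and the number of
                           # <|image|> placeholders inserted per image
```

Note also that the new TrainConfig defaults keep a comparable effective batch size: batch_size * gradient_accumulation_steps = 32 * 4 = 128, versus 256 * 1 before.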