Fix adapter v2 llm.int8 inference #323

Open · wants to merge 8 commits into base: main
Changes from 6 commits
finetune/adapter_v2.py (2 changes: 1 addition & 1 deletion)

@@ -71,7 +71,7 @@ def main(

     fabric = L.Fabric(
         accelerator="cuda",
-        devices=1,
+        devices=devices,
         strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
         precision="bf16-true",
     )
generate/adapter_v2.py (13 changes: 10 additions & 3 deletions)

@@ -65,7 +65,7 @@ def main(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ):
         model = LLaMA.from_name(name)
-    add_adapter_v2_parameters_to_linear_layers(model)
+    add_adapter_v2_parameters_to_linear_layers(model, dtype=dtype, quantize=quantize)

     # 1. Load the pretrained weights
     model.load_state_dict(pretrained_checkpoint, strict=False)

@@ -84,10 +84,17 @@ def main(
     prompt_length = encoded.size(0)

     t0 = time.perf_counter()
-    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
+    y = generate(
+        model,
+        idx=encoded,
+        max_seq_length=max_new_tokens,
Review comments on this line:

Contributor:
What's the reasoning behind this?

Contributor:
Using named arguments for easier debugging, I guess.

Contributor:
I mean passing max_seq_length=max_new_tokens. This limits it a lot, since by default it would equal the block_size.

rasbt (Contributor), Jun 7, 2023:
@carmocca I don't know, to be honest. I adopted this from the regular LLaMA adapter in the adapter.py script when I originally implemented adapter_v2.py. It's also like this in generate/lora.py.

I'd say it's okay to leave this for this PR, but then we may want to open an issue/PR to revisit this for ALL generate scripts.

Contributor:
The links you shared are doing it as I'm suggesting. To be clear, this is what I mean:

Suggested change:
-        max_seq_length=max_new_tokens,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_k=top_k,
+        eos_id=tokenizer.eos_id
+    )
     t = time.perf_counter() - t0

     model.reset_cache()
     output = tokenizer.decode(y)
     output = output.split("### Response:")[1].strip()
     print(output)
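Not part of the diff: for concreteness, this is roughly what the call becomes once the suggested change above is applied (a sketch reusing the names from the diff; that generate() then falls back to the model's block size is taken from the review comment, not verified here):

    # Sketch of the generate() call with the max_seq_length override removed,
    # so the function uses its default context length.
    y = generate(
        model,
        idx=encoded,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )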
lit_llama/adapter_v2.py (21 changes: 15 additions & 6 deletions)

@@ -2,6 +2,7 @@
 from torch import Tensor
 import torch.nn as nn
 from torch.nn import functional as F
+from lit_llama.quantization import Linear8bitLt

 from lit_llama.adapter import LLaMA

@@ -26,20 +27,28 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:


 def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
+    weight = self.weight
+    if isinstance(self, Linear8bitLt):
+        weight = self.dequantize(input.dtype)
     return self.adapter_scale * (
-        F.linear(input, self.weight, self.bias) + self.adapter_bias
+        F.linear(input, weight, self.bias) + self.adapter_bias
     )


-def adapter_v2_linear_with_bias_and_scale(layer):
-    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
-    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+def adapter_v2_linear_with_bias_and_scale(layer, dtype=None, quantize=False):
+    weight = layer.weight
+
+    if dtype is not None and quantize:
+        if isinstance(layer, Linear8bitLt):
Review comments on this hunk:

Contributor:
If the snippet above uses a try/except, wouldn't you want it here too?

rasbt (Contributor), Jun 7, 2023:
@carmocca Good point. And I just remembered why I didn't do it: I had some issues here with that.

Some people may not be able to install bitsandbytes, and that shouldn't prevent them from using the adapter method without quantization. That's why I added the quantize argument here. But if someone is using the --quantize flag, which sets quantize=True here, AND bitsandbytes cannot be imported, then it SHOULD fail, because otherwise it would silently run without quantization, which is not what's intended when someone passes --quantize.

Now, in the case above where I used the try/except, I failed to make it work with the quantize argument because I am overriding the default forward method, and I don't think it's easily possible to add that as an argument. I'm actually not sure about that and would need some help here.

I also think we want to remove that try/except somehow, as it is wasteful to fail an import on every forward call. Any ideas?

[Screenshot attached: 2023-06-07, 4:24:53 PM]

carmocca (Contributor), Jun 11, 2023:
I see (nice image!).

We could do this with functools.partial: partial(adapter_v2_new_forward, quantize=quantize)

+            weight = layer.dequantize(dtype)
+    layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True)
+    layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True)
     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
     setattr(layer, 'forward', bound_method)
     return layer


-def add_adapter_v2_parameters_to_linear_layers(model):
+def add_adapter_v2_parameters_to_linear_layers(model, dtype=None, quantize=False):
     for module in model.modules():
         if isinstance(module, nn.Linear):
-            adapter_v2_linear_with_bias_and_scale(module)
+            adapter_v2_linear_with_bias_and_scale(module, dtype=dtype, quantize=quantize)
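Not part of the diff: a minimal sketch of the functools.partial approach suggested in the thread above, assuming a hypothetical quantize keyword on adapter_v2_new_forward. The flag is baked in once when the layer is patched, so the forward pass needs neither a per-call isinstance check nor a try/except around the bitsandbytes import.

    # Sketch only, not the PR's implementation. The quantize keyword on
    # adapter_v2_new_forward is hypothetical; functools.partial bakes it in
    # when the layer is patched, mirroring carmocca's suggestion.
    from functools import partial

    import torch
    import torch.nn as nn
    from torch.nn import functional as F

    def adapter_v2_new_forward(self, input: torch.Tensor, quantize: bool = False) -> torch.Tensor:
        weight = self.weight
        if quantize:
            # assumes an int8 layer exposing dequantize(), as in the diff above
            weight = self.dequantize(input.dtype)
        return self.adapter_scale * (F.linear(input, weight, self.bias) + self.adapter_bias)

    def adapter_v2_linear_with_bias_and_scale(layer, quantize=False):
        layer.adapter_bias = nn.Parameter(torch.zeros(layer.weight.shape[0]))
        layer.adapter_scale = nn.Parameter(torch.ones(layer.weight.shape[0]))
        # partial() pre-binds the layer and the flag; assigning it as an
        # instance attribute shadows nn.Module.forward, like the __get__ trick.
        layer.forward = partial(adapter_v2_new_forward, layer, quantize=quantize)
        return layer

    # The quantize=False path works unchanged for a plain nn.Linear:
    patched = adapter_v2_linear_with_bias_and_scale(nn.Linear(8, 4))
    print(patched(torch.randn(2, 8)).shape)  # torch.Size([2, 4])

Under this scheme the bitsandbytes-dependent dequantize path would only be selected at setup time, so a missing bitsandbytes installation could fail loudly when --quantize is requested and be ignored otherwise, which matches the behaviour rasbt describes.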
lit_llama/quantization.py (11 changes: 11 additions & 0 deletions)

@@ -74,6 +74,17 @@ def _quantize_weight(self, weight: torch.Tensor) -> None:
         setattr(self.weight, "CB", CB)
         setattr(self.weight, "SCB", SCB)

+    def dequantize(self, dtype):
+        if dtype not in [torch.bfloat16, torch.float16, torch.float32]:
+            raise ValueError(f"Invalid dtype: {dtype}. Allowed dtypes are: bfloat16, float16, float32")
+        weight_CB = self.weight.CB
+        weight_SCB = self.weight.SCB
+        # Reshape SCB if its shape doesn't match CB
+        if weight_CB.shape[1] != weight_SCB.shape[0]:
+            weight_SCB = weight_SCB.view(weight_SCB.shape[0], 1)
+        result = (weight_CB * weight_SCB) / 127
+        result = result.to(dtype)
+        return result

 if triton is not None:
     # This is adapted from the OpenAI Triton matmul example.
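Not part of the diff: a standalone round-trip sketch of the arithmetic the new dequantize() method relies on. LLM.int8 in bitsandbytes stores the int8 matrix CB plus a per-row absmax scale SCB, and an approximate floating-point weight is recovered as CB * SCB / 127. Names mirror the diff; the toy per-row quantizer below is an assumption for illustration, not bitsandbytes code.

    # Round-trip sketch of the dequantization math in dequantize() above.
    # CB: int8 weight matrix; SCB: per-row absmax scale (bitsandbytes naming).
    import torch

    def dequantize_sketch(CB: torch.Tensor, SCB: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        if dtype not in (torch.bfloat16, torch.float16, torch.float32):
            raise ValueError(f"Invalid dtype: {dtype}")
        if CB.shape[1] != SCB.shape[0]:
            # reshape so the per-row scale broadcasts over columns, as in the diff
            SCB = SCB.view(SCB.shape[0], 1)
        return ((CB * SCB) / 127).to(dtype)

    # Toy per-row absmax quantization to int8, then dequantization.
    w = torch.randn(4, 8)
    scb = w.abs().max(dim=1).values              # per-row scale
    cb = torch.round(w / scb.unsqueeze(1) * 127).to(torch.int8)
    w_approx = dequantize_sketch(cb, scb, torch.float32)
    # Reconstruction error is at most half an int8 step per row.
    assert torch.allclose(w, w_approx, atol=float(scb.max()) / 127)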