Fix adapter v2 llm.int8 inference #323
base: main
Changes from all commits: b77cbcc, 2824293, 2a57575, 961a522, 3d88c3c, ae36781, 257bcda, c93c39f
```diff
@@ -26,20 +26,34 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
 
 
 def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
+    weight = self.weight
+
+    try:
+        from lit_llama.quantization import Linear8bitLt
+        if isinstance(self, Linear8bitLt):
+            weight = self.dequantize(input.dtype)
+    except:
+        None
+
     return self.adapter_scale * (
-        F.linear(input, self.weight, self.bias) + self.adapter_bias
+        F.linear(input, weight, self.bias) + self.adapter_bias
     )
```

> **Review comment** on the new `except:` / `None` lines (conversation marked as resolved by carmocca): It's more common to … *(suggested change truncated)*
```diff
 
 
-def adapter_v2_linear_with_bias_and_scale(layer):
-    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
-    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+def adapter_v2_linear_with_bias_and_scale(layer, dtype=None, quantize=False):
+    weight = layer.weight
+
+    if dtype is not None and quantize:
+        from lit_llama.quantization import Linear8bitLt
+        if isinstance(layer, Linear8bitLt):
```
> **carmocca:** If the snippet above uses a try-catch, wouldn't you want it here too?
>
> **PR author:** @carmocca Good point. And I just remembered now why I didn't do it: I had some issues here with that. Some people may not be able to install … Now, in the case above where I used the try-except, I failed making it work with the … I think we actually want to remove the try-except above somehow, as it is stupid and expensive if it has to fail to import something every time a forward call happens. Any ideas? *(image)*
>
> **carmocca:** I see (nice image!). We could do this with …
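The last comment above is truncated, so the concrete proposal is not visible here. One common pattern that avoids both the bare `except:` and the repeated import attempt on every `forward` call is to resolve the optional dependency once, at module import time. The following is only a sketch under that assumption, not code from this PR:

```python
import torch.nn.functional as F
from torch import Tensor

# Sketch: attempt the optional import once when the module is loaded, instead
# of retrying (and possibly failing) inside every forward call.
try:
    from lit_llama.quantization import Linear8bitLt
except ImportError:  # quantization support (e.g. bitsandbytes) not installed
    Linear8bitLt = None


def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
    weight = self.weight
    # Cheap per-call check; no import machinery involved anymore.
    if Linear8bitLt is not None and isinstance(self, Linear8bitLt):
        weight = self.dequantize(input.dtype)
    return self.adapter_scale * (
        F.linear(input, weight, self.bias) + self.adapter_bias
    )
```

The PR diff then continues: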
```diff
+            weight = layer.dequantize(dtype)
+    layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True)
+    layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True)
     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
     setattr(layer, 'forward', bound_method)
     return layer
 
 
-def add_adapter_v2_parameters_to_linear_layers(model):
+def add_adapter_v2_parameters_to_linear_layers(model, dtype=None, quantize=False):
     for module in model.modules():
         if isinstance(module, nn.Linear):
-            adapter_v2_linear_with_bias_and_scale(module)
+            adapter_v2_linear_with_bias_and_scale(module, dtype=dtype, quantize=quantize)
```
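With this change, callers have to thread the target `dtype` and the `quantize` flag through to the adapter setup. A minimal sketch of how a generate script might call the new signature; the model name, dtype, and quantize value are illustrative, and the model is assumed to have been built with llm.int8 quantization enabled so that its linear layers are `Linear8bitLt`:

```python
import torch
from lit_llama import LLaMA
from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers

dtype = torch.bfloat16   # dtype the dequantized weights should take
quantize = "llm.int8"    # truthy flag; illustrative value

# Assumption: in the real scripts the model is instantiated under the usual
# quantization context so nn.Linear layers are replaced by Linear8bitLt;
# plain construction is shown here for brevity.
model = LLaMA.from_name("7B")

# Attach the adapter_v2 bias/scale parameters, sized from the dequantized weights.
add_adapter_v2_parameters_to_linear_layers(model, dtype=dtype, quantize=quantize)
```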
A separate review thread (the file it refers to is not shown above):

> **carmocca:** What's the reasoning behind this?
>
> **Reply:** Using named arguments for easier debugging, I guess.
>
> **carmocca:** I mean passing `max_seq_length=max_new_tokens`. This would limit it a lot, as by default it will equal the `block_size`.
>
> **PR author:** @carmocca I don't know, to be honest. I adopted this from the regular LLaMA adapter in the adapter.py script when I originally implemented adapter_v2.py. It's also like this in `generate/lora.py`. I'd say it's okay to leave this for this PR, but then we maybe want to open an issue/PR to revisit this for ALL generate scripts?
>
> **carmocca:** The links you shared are doing it as I'm suggesting. To be clear, this is what I mean: …
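To make the trade-off in this thread concrete, here is a small numerical sketch. It assumes, as the comments suggest, that `max_seq_length` caps the total sequence the model will see (prompt plus generated tokens) and that it defaults to the model's `block_size`; the numbers are made up and this is not lit-llama code.

```python
# Illustrative numbers only.
block_size = 2048       # model context window; the default max_seq_length
prompt_len = 300        # tokens already in the prompt
max_new_tokens = 100    # tokens we want to generate

# Passing max_seq_length=max_new_tokens caps the whole sequence at 100 tokens,
# which is smaller than the prompt alone in this example.
capped = max_new_tokens

# Default behaviour referenced in the review: the full context window.
default = block_size

# A tighter alternative that still fits everything that will be produced.
fitted = min(prompt_len + max_new_tokens, block_size)

print(capped, default, fitted)  # 100 2048 400
```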