From b77cbccf258219aeae9f35a1583688ae41f8feb3 Mon Sep 17 00:00:00 2001 From: Dior Miu Date: Wed, 24 May 2023 18:26:49 -0500 Subject: [PATCH 1/8] convert adapterv2 weight to input dtype --- lit_llama/adapter_v2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py index 368e695f..02d66c42 100644 --- a/lit_llama/adapter_v2.py +++ b/lit_llama/adapter_v2.py @@ -2,6 +2,7 @@ from torch import Tensor import torch.nn as nn from torch.nn import functional as F +from bitsandbytes.nn import SwitchBackLinear, Linear8bitLt from lit_llama.adapter import LLaMA @@ -26,8 +27,11 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict: def adapter_v2_new_forward(self, input: Tensor) -> Tensor: + weight = self.weight + if isinstance(self, Linear8bitLt): + weight = SwitchBackLinear(self.in_features, self.out_features, dtype=input.dtype, device=input.device).weight return self.adapter_scale * ( - F.linear(input, self.weight, self.bias) + self.adapter_bias + F.linear(input, weight, self.bias) + self.adapter_bias ) From 2824293e3e86a7329dfe465868ed588198f9ca59 Mon Sep 17 00:00:00 2001 From: Dior Miu Date: Thu, 25 May 2023 10:06:22 -0500 Subject: [PATCH 2/8] dequantize adapter/layers weights --- generate/adapter_v2.py | 2 +- lit_llama/adapter_v2.py | 17 ++++++++++------- lit_llama/quantization.py | 11 +++++++++++ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index 95ae38e3..41f01982 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -65,7 +65,7 @@ def main( device=fabric.device, dtype=dtype, quantization_mode=quantize ): model = LLaMA.from_name(name) - add_adapter_v2_parameters_to_linear_layers(model) + add_adapter_v2_parameters_to_linear_layers(model, dtype) # 1. 
Load the pretrained weights model.load_state_dict(pretrained_checkpoint, strict=False) diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py index 02d66c42..7b9b80cc 100644 --- a/lit_llama/adapter_v2.py +++ b/lit_llama/adapter_v2.py @@ -2,7 +2,7 @@ from torch import Tensor import torch.nn as nn from torch.nn import functional as F -from bitsandbytes.nn import SwitchBackLinear, Linear8bitLt +from lit_llama.quantization import Linear8bitLt from lit_llama.adapter import LLaMA @@ -29,21 +29,24 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict: def adapter_v2_new_forward(self, input: Tensor) -> Tensor: weight = self.weight if isinstance(self, Linear8bitLt): - weight = SwitchBackLinear(self.in_features, self.out_features, dtype=input.dtype, device=input.device).weight + weight = self.dequantize(input.dtype) return self.adapter_scale * ( F.linear(input, weight, self.bias) + self.adapter_bias ) -def adapter_v2_linear_with_bias_and_scale(layer): - layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True) - layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True) +def adapter_v2_linear_with_bias_and_scale(layer, dtype): + weight = layer.weight + if isinstance(layer, Linear8bitLt): + weight = layer.dequantize(dtype) + layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True) + layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True) bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__) setattr(layer, 'forward', bound_method) return layer -def add_adapter_v2_parameters_to_linear_layers(model): +def add_adapter_v2_parameters_to_linear_layers(model, dtype): for module in model.modules(): if isinstance(module, nn.Linear): - adapter_v2_linear_with_bias_and_scale(module) + adapter_v2_linear_with_bias_and_scale(module, dtype) diff --git a/lit_llama/quantization.py b/lit_llama/quantization.py index 668a39f4..ba1df00c 100644 --- a/lit_llama/quantization.py +++ b/lit_llama/quantization.py @@ -74,6 +74,17 @@ def _quantize_weight(self, weight: torch.Tensor) -> None: setattr(self.weight, "CB", CB) setattr(self.weight, "SCB", SCB) + def dequantize(self, dtype): + if dtype not in [torch.bfloat16, torch.float16, torch.float32]: + raise ValueError(f"Invalid dtype: {dtype}. Allowed dtypes are: bfloat16, float16, float32") + weight_CB = self.weight.CB + weight_SCB = self.weight.SCB + # Modify SBC shape if it doesn't match CB + if weight_CB.shape[1] != weight_SCB.shape[0]: + weight_SCB = weight_SCB.view(weight_SCB.shape[0], 1) + result = (weight_CB * weight_SCB) / 127 + result = result.to(dtype) + return result if triton is not None: # This is adapted from the OpenAI Triton matmul example. From 2a575754ae6d78c76f482d6e4a35dde1506aa584 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Fri, 26 May 2023 17:02:23 -0500 Subject: [PATCH 3/8] Update adapter_v2.py Small fixes to make the generate function work. 
--- generate/adapter_v2.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index 41f01982..1f2c3013 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -84,10 +84,17 @@ def main( prompt_length = encoded.size(0) t0 = time.perf_counter() - y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id) + y = generate( + model, + idx=encoded, + max_seq_length=max_new_tokens, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id + ) t = time.perf_counter() - t0 - model.reset_cache() output = tokenizer.decode(y) output = output.split("### Response:")[1].strip() print(output) From 961a5229e2305e2805d4fad6ccf65915c37a6605 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Fri, 2 Jun 2023 13:33:26 -0500 Subject: [PATCH 4/8] fix dtype usage --- datasets | 1 + finetune/adapter_v2.py | 2 +- generate/adapter_v2.py | 2 +- lit_llama/adapter_v2.py | 7 ++++--- 4 files changed, 7 insertions(+), 5 deletions(-) create mode 120000 datasets diff --git a/datasets b/datasets new file mode 120000 index 00000000..accf3ce9 --- /dev/null +++ b/datasets @@ -0,0 +1 @@ +/srv/data/datasets \ No newline at end of file diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py index 2dd3c0bd..3db0e185 100644 --- a/finetune/adapter_v2.py +++ b/finetune/adapter_v2.py @@ -71,7 +71,7 @@ def main( fabric = L.Fabric( accelerator="cuda", - devices=1, + devices=devices, strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"), precision="bf16-true", ) diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index 1f2c3013..b5175b35 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -65,7 +65,7 @@ def main( device=fabric.device, dtype=dtype, quantization_mode=quantize ): model = LLaMA.from_name(name) - add_adapter_v2_parameters_to_linear_layers(model, dtype) + add_adapter_v2_parameters_to_linear_layers(model, dtype=dtype) # 1. 
Load the pretrained weights model.load_state_dict(pretrained_checkpoint, strict=False) diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py index 7b9b80cc..2666233c 100644 --- a/lit_llama/adapter_v2.py +++ b/lit_llama/adapter_v2.py @@ -35,9 +35,10 @@ def adapter_v2_new_forward(self, input: Tensor) -> Tensor: ) -def adapter_v2_linear_with_bias_and_scale(layer, dtype): +def adapter_v2_linear_with_bias_and_scale(layer, dtype=None): weight = layer.weight - if isinstance(layer, Linear8bitLt): + + if dtype is not None and isinstance(layer, Linear8bitLt): weight = layer.dequantize(dtype) layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True) layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True) @@ -46,7 +47,7 @@ def adapter_v2_linear_with_bias_and_scale(layer, dtype): return layer -def add_adapter_v2_parameters_to_linear_layers(model, dtype): +def add_adapter_v2_parameters_to_linear_layers(model, dtype=None): for module in model.modules(): if isinstance(module, nn.Linear): adapter_v2_linear_with_bias_and_scale(module, dtype) From 3d88c3cd5fdacba064ca1582f4ca269309e66cf3 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Fri, 2 Jun 2023 13:35:08 -0500 Subject: [PATCH 5/8] delete dataset symlink --- datasets | 1 - 1 file changed, 1 deletion(-) delete mode 120000 datasets diff --git a/datasets b/datasets deleted file mode 120000 index accf3ce9..00000000 --- a/datasets +++ /dev/null @@ -1 +0,0 @@ -/srv/data/datasets \ No newline at end of file From ae36781426b1c7ca55391d22bfbabe5a9e360145 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Fri, 2 Jun 2023 14:00:08 -0500 Subject: [PATCH 6/8] make Linear8bitLt import optional --- generate/adapter_v2.py | 2 +- lit_llama/adapter_v2.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index b5175b35..3a366f98 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -65,7 +65,7 @@ def main( device=fabric.device, dtype=dtype, quantization_mode=quantize ): model = LLaMA.from_name(name) - add_adapter_v2_parameters_to_linear_layers(model, dtype=dtype) + add_adapter_v2_parameters_to_linear_layers(model, dtype=dtype, quantize=quantize) # 1. 
Load the pretrained weights model.load_state_dict(pretrained_checkpoint, strict=False) diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py index 2666233c..a49b8a5d 100644 --- a/lit_llama/adapter_v2.py +++ b/lit_llama/adapter_v2.py @@ -35,11 +35,12 @@ def adapter_v2_new_forward(self, input: Tensor) -> Tensor: ) -def adapter_v2_linear_with_bias_and_scale(layer, dtype=None): +def adapter_v2_linear_with_bias_and_scale(layer, dtype=None, quantize=False): weight = layer.weight - if dtype is not None and isinstance(layer, Linear8bitLt): - weight = layer.dequantize(dtype) + if dtype is not None and quantize: + if isinstance(layer, Linear8bitLt): + weight = layer.dequantize(dtype) layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True) layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True) bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__) @@ -47,7 +48,7 @@ def adapter_v2_linear_with_bias_and_scale(layer, dtype=None): return layer -def add_adapter_v2_parameters_to_linear_layers(model, dtype=None): +def add_adapter_v2_parameters_to_linear_layers(model, dtype=None, quantize=False): for module in model.modules(): if isinstance(module, nn.Linear): - adapter_v2_linear_with_bias_and_scale(module, dtype) + adapter_v2_linear_with_bias_and_scale(module, dtype=dtype, quantize=quantize) From 257bcda7b61dab7dc2d9120177c530ef5f98d398 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Fri, 2 Jun 2023 14:14:34 -0500 Subject: [PATCH 7/8] another attempt --- lit_llama/adapter_v2.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py index a49b8a5d..3afc7ceb 100644 --- a/lit_llama/adapter_v2.py +++ b/lit_llama/adapter_v2.py @@ -2,7 +2,6 @@ from torch import Tensor import torch.nn as nn from torch.nn import functional as F -from lit_llama.quantization import Linear8bitLt from lit_llama.adapter import LLaMA @@ -28,8 +27,13 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict: def adapter_v2_new_forward(self, input: Tensor) -> Tensor: weight = self.weight - if isinstance(self, Linear8bitLt): - weight = self.dequantize(input.dtype) + + try: + from lit_llama.quantization import Linear8bitLt + if isinstance(self, Linear8bitLt): + weight = self.dequantize(input.dtype) + except: + None return self.adapter_scale * ( F.linear(input, weight, self.bias) + self.adapter_bias ) @@ -39,6 +43,7 @@ def adapter_v2_linear_with_bias_and_scale(layer, dtype=None, quantize=False): weight = layer.weight if dtype is not None and quantize: + from lit_llama.quantization import Linear8bitLt if isinstance(layer, Linear8bitLt): weight = layer.dequantize(dtype) layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True) From c93c39ffbe3ceada805a90626fe7fdc574b5fed6 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Wed, 7 Jun 2023 15:56:58 -0500 Subject: [PATCH 8/8] add model.reset_cache() back --- generate/adapter_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py index 3a366f98..d6d9ef6a 100644 --- a/generate/adapter_v2.py +++ b/generate/adapter_v2.py @@ -95,6 +95,7 @@ def main( ) t = time.perf_counter() - t0 + model.reset_cache() # technically only needed if max_seq_length varies output = tokenizer.decode(y) output = output.split("### Response:")[1].strip() print(output)
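
Note on the dequantization introduced in [PATCH 2/8]: the new `Linear8bitLt.dequantize` recovers a floating-point weight from the stored int8 payload via `(CB * SCB) / 127`, i.e. it reverses a row-wise absmax quantization. Below is a minimal, self-contained sketch of that round trip; it is not the bitsandbytes internals, and `int8_quantize` / `int8_dequantize` are hypothetical helper names used only for illustration.

import torch

def int8_quantize(weight: torch.Tensor):
    # Per-output-row absolute maximum (analogous to SCB), kept as a column so it broadcasts.
    scb = weight.abs().max(dim=1, keepdim=True).values
    # Scale each row into [-127, 127] and round to int8 (analogous to CB).
    cb = torch.round(weight / scb * 127).to(torch.int8)
    return cb, scb

def int8_dequantize(cb: torch.Tensor, scb: torch.Tensor, dtype: torch.dtype = torch.float16):
    # Mirrors the patch: (CB * SCB) / 127, then cast to the requested dtype.
    return (cb.to(scb.dtype) * scb / 127).to(dtype)

w = torch.randn(4, 8)
cb, scb = int8_quantize(w)
w_hat = int8_dequantize(cb, scb, dtype=torch.float32)
print((w - w_hat).abs().max())  # small round-off error from the int8 grid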
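
Note on the adapter wiring: `adapter_v2_linear_with_bias_and_scale` patches existing `nn.Linear` instances by binding `adapter_v2_new_forward` onto each instance with `__get__`, so only the selected layers pick up the extra `adapter_scale` / `adapter_bias` parameters. The following is a standalone toy sketch of that instance-level monkey-patching pattern, not the lit-llama code itself.

import torch
import torch.nn as nn
from torch.nn import functional as F

def adapter_forward(self, input: torch.Tensor) -> torch.Tensor:
    # Same computation shape as adapter_v2_new_forward: scale the (linear + learnable bias) output.
    return self.adapter_scale * (F.linear(input, self.weight, self.bias) + self.adapter_bias)

layer = nn.Linear(16, 8)
layer.adapter_bias = nn.Parameter(torch.zeros(layer.weight.shape[0]))
layer.adapter_scale = nn.Parameter(torch.ones(layer.weight.shape[0]))
# Bind the function to this instance only; other nn.Linear modules keep their original forward.
layer.forward = adapter_forward.__get__(layer, layer.__class__)

out = layer(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 8])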