diff --git a/finetune/adapter_v2.py b/finetune/adapter_v2.py
index 2dd3c0bd..3db0e185 100644
--- a/finetune/adapter_v2.py
+++ b/finetune/adapter_v2.py
@@ -71,7 +71,7 @@ def main(
 
     fabric = L.Fabric(
         accelerator="cuda",
-        devices=1,
+        devices=devices,
         strategy=(DeepSpeedStrategy(config=ds_config) if devices > 1 else "auto"),
         precision="bf16-true",
     )
diff --git a/generate/adapter_v2.py b/generate/adapter_v2.py
index 95ae38e3..d6d9ef6a 100644
--- a/generate/adapter_v2.py
+++ b/generate/adapter_v2.py
@@ -65,7 +65,7 @@ def main(
             device=fabric.device, dtype=dtype, quantization_mode=quantize
         ):
             model = LLaMA.from_name(name)
-            add_adapter_v2_parameters_to_linear_layers(model)
+            add_adapter_v2_parameters_to_linear_layers(model, dtype=dtype, quantize=quantize)
 
         # 1. Load the pretrained weights
         model.load_state_dict(pretrained_checkpoint, strict=False)
@@ -84,10 +84,18 @@ def main(
     prompt_length = encoded.size(0)
 
     t0 = time.perf_counter()
-    y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id)
+    y = generate(
+        model,
+        idx=encoded,
+        max_seq_length=max_new_tokens,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_k=top_k,
+        eos_id=tokenizer.eos_id
+    )
     t = time.perf_counter() - t0
 
-    model.reset_cache()
+    model.reset_cache()  # technically only needed if max_seq_length varies
     output = tokenizer.decode(y)
     output = output.split("### Response:")[1].strip()
     print(output)
diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py
index 368e695f..3afc7ceb 100644
--- a/lit_llama/adapter_v2.py
+++ b/lit_llama/adapter_v2.py
@@ -26,20 +26,34 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:
 
 
 def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
+    weight = self.weight
+
+    try:
+        from lit_llama.quantization import Linear8bitLt
+        if isinstance(self, Linear8bitLt):
+            weight = self.dequantize(input.dtype)
+    except ImportError:
+        pass
     return self.adapter_scale * (
-        F.linear(input, self.weight, self.bias) + self.adapter_bias
+        F.linear(input, weight, self.bias) + self.adapter_bias
     )
 
 
-def adapter_v2_linear_with_bias_and_scale(layer):
-    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
-    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+def adapter_v2_linear_with_bias_and_scale(layer, dtype=None, quantize=False):
+    weight = layer.weight
+
+    if dtype is not None and quantize:
+        from lit_llama.quantization import Linear8bitLt
+        if isinstance(layer, Linear8bitLt):
+            weight = layer.dequantize(dtype)
+    layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True)
+    layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True)
     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
     setattr(layer, 'forward', bound_method)
     return layer
 
 
-def add_adapter_v2_parameters_to_linear_layers(model):
+def add_adapter_v2_parameters_to_linear_layers(model, dtype=None, quantize=False):
     for module in model.modules():
         if isinstance(module, nn.Linear):
-            adapter_v2_linear_with_bias_and_scale(module)
+            adapter_v2_linear_with_bias_and_scale(module, dtype=dtype, quantize=quantize)
diff --git a/lit_llama/quantization.py b/lit_llama/quantization.py
index 668a39f4..ba1df00c 100644
--- a/lit_llama/quantization.py
+++ b/lit_llama/quantization.py
@@ -74,6 +74,17 @@ def _quantize_weight(self, weight: torch.Tensor) -> None:
         setattr(self.weight, "CB", CB)
         setattr(self.weight, "SCB", SCB)
 
+    def dequantize(self, dtype):
+        if dtype not in [torch.bfloat16, torch.float16, torch.float32]:
+            raise ValueError(f"Invalid dtype: {dtype}. Allowed dtypes are: bfloat16, float16, float32")
+        weight_CB = self.weight.CB
+        weight_SCB = self.weight.SCB
+        # Reshape SCB to a column if its shape doesn't match CB
+        if weight_CB.shape[1] != weight_SCB.shape[0]:
+            weight_SCB = weight_SCB.view(weight_SCB.shape[0], 1)
+        result = (weight_CB * weight_SCB) / 127
+        result = result.to(dtype)
+        return result
 
 if triton is not None:
     # This is adapted from the OpenAI Triton matmul example.
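
Review note: the new Linear8bitLt.dequantize reconstructs an approximate floating-point weight from the int8 CB matrix and the per-row SCB scales as CB * SCB / 127, and adapter_v2_new_forward calls it so the adapter scale and bias wrap a full-precision matmul even when the base weights are quantized. Below is a minimal, self-contained sketch of that arithmetic, independent of bitsandbytes; the name dequantize_int8 and the toy round trip are illustrative only, not part of the patch.

import torch

def dequantize_int8(CB: torch.Tensor, SCB: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # CB:  (out_features, in_features) int8 quantized weight
    # SCB: (out_features,) per-row absmax scales recorded at quantization time
    if dtype not in (torch.bfloat16, torch.float16, torch.float32):
        raise ValueError(f"Invalid dtype: {dtype}")
    scales = SCB.view(-1, 1)                    # column shape so it broadcasts over rows of CB
    return (CB.to(torch.float32) * scales / 127).to(dtype)

# Toy round trip: quantize a random weight with per-row absmax, then dequantize it.
w = torch.randn(4, 8)
scb = w.abs().amax(dim=1)                       # per-row absmax
cb = torch.round(w / scb.view(-1, 1) * 127).to(torch.int8)
w_hat = dequantize_int8(cb, scb, torch.float32)
print((w - w_hat).abs().max())                  # small, bounded by the int8 step size

The explicit view of SCB to a column is the same broadcasting concern the shape check inside the patched dequantize handles before multiplying against CB.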