diff --git a/lit_llama/adapter_v2.py b/lit_llama/adapter_v2.py
index 368e695f..fb9d8afa 100644
--- a/lit_llama/adapter_v2.py
+++ b/lit_llama/adapter_v2.py
@@ -32,8 +32,14 @@ def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
 
 
 def adapter_v2_linear_with_bias_and_scale(layer):
-    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
-    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+    layer.adapter_bias = torch.nn.Parameter(
+        torch.zeros(layer.weight.shape[0], dtype=layer.weight.dtype),
+        requires_grad=False,
+    )
+    layer.adapter_scale = torch.nn.Parameter(
+        torch.ones(layer.weight.shape[0], dtype=layer.weight.dtype),
+        requires_grad=False,
+    )
     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
     setattr(layer, 'forward', bound_method)
     return layer
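
Note on the change above: the adapter parameters are now created in the same dtype as the layer's weight, so patching a model loaded in float16/bfloat16 no longer mixes float32 parameters into the forward pass, and they start with requires_grad=False, which makes the patched layer safe for inference by default (a finetuning script can re-enable gradients on them afterwards). Below is a minimal, self-contained sketch of the technique; the body of adapter_v2_new_forward is an assumed reconstruction based on the Adapter V2 idea of a learnable scale and bias around the frozen linear projection, not a copy of the repo's implementation:

import torch
import torch.nn.functional as F
from torch import Tensor


def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
    # Assumed reconstruction: scale the bias-shifted linear output,
    # following the Adapter V2 formulation.
    return self.adapter_scale * (F.linear(input, self.weight, self.bias) + self.adapter_bias)


def adapter_v2_linear_with_bias_and_scale(layer):
    # Mirror the patched code: the new parameters share the weight's dtype
    # and start frozen (requires_grad=False).
    layer.adapter_bias = torch.nn.Parameter(
        torch.zeros(layer.weight.shape[0], dtype=layer.weight.dtype),
        requires_grad=False,
    )
    layer.adapter_scale = torch.nn.Parameter(
        torch.ones(layer.weight.shape[0], dtype=layer.weight.dtype),
        requires_grad=False,
    )
    # __get__ binds the plain function to this one instance, so layer(x)
    # dispatches to the patched forward while other Linear modules are untouched.
    bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
    setattr(layer, 'forward', bound_method)
    return layer


layer = adapter_v2_linear_with_bias_and_scale(torch.nn.Linear(4, 3).to(torch.bfloat16))
x = torch.randn(2, 4, dtype=torch.bfloat16)
out = layer(x)
assert out.dtype == torch.bfloat16  # no float32 leaks in from the adapter params

With the old float32 initialization, the final assert would fail (or the forward would upcast) whenever the model is loaded in half precision; matching dtype at creation time avoids that without any .to() calls after patching.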