diff --git a/src/compressed_tensors/transform/factory/hadamard.py b/src/compressed_tensors/transform/factory/hadamard.py
index c14da51f..06e0886a 100644
--- a/src/compressed_tensors/transform/factory/hadamard.py
+++ b/src/compressed_tensors/transform/factory/hadamard.py
@@ -22,7 +22,7 @@
     apply_transform_weight,
     get_matrix_size,
 )
-from compressed_tensors.utils import get_offloaded_device
+from compressed_tensors.utils import get_execution_device, get_offloaded_device
 from compressed_tensors.utils.helpers import ParameterizedDefaultDict
 from torch import Tensor, device, dtype
 from torch.nn import Linear, Module, Parameter
@@ -54,12 +54,20 @@ def create_transform(self, module: Module, args: TransformArgs):
         size = get_matrix_size(module, args.location)
         dtype = module.weight.dtype
         device = get_offloaded_device(module)
+        exec_device = get_execution_device(module)
 
-        weight = self.weights[size, dtype, device]
+        weight = self.weights.get(size, dtype, device, construct_device=exec_device)
         return HadamardTransform(weight, args)
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = deterministic_hadamard_matrix(size, dtype, device)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = deterministic_hadamard_matrix(size, torch.float32, construct_device)
+        data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
 
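The hadamard.py change above follows a construct-then-offload pattern: the matrix is built in float32 on the module's execution device (typically a GPU), then cast and moved once to the offload device, where the cached copy lives. A minimal sketch of the pattern, with torch.eye standing in for deterministic_hadamard_matrix and the helper name build_cached_weight invented for illustration:

import torch
from torch.nn import Parameter

def build_cached_weight(
    size: int,
    dtype: torch.dtype,
    offload_device: torch.device,
    construct_device: torch.device,
) -> Parameter:
    # build in float32 on the fast execution device (e.g. "cuda:0");
    # torch.eye is only a stand-in for the real Hadamard construction
    data = torch.eye(size, dtype=torch.float32, device=construct_device)
    # cast and move once, so the cached tensor lives on the offload device
    data = data.to(dtype=dtype, device=offload_device)
    return Parameter(data, requires_grad=False)

Constructing in float32 before the final cast presumably keeps the construction numerically exact even when the target dtype is low-precision (e.g. float16).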
diff --git a/src/compressed_tensors/transform/factory/random_hadamard.py b/src/compressed_tensors/transform/factory/random_hadamard.py
index 78fb6975..2b17be54 100644
--- a/src/compressed_tensors/transform/factory/random_hadamard.py
+++ b/src/compressed_tensors/transform/factory/random_hadamard.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import torch
 from compressed_tensors.transform import HadamardFactory, TransformFactory
 from compressed_tensors.transform.utils.hadamard import random_hadamard_matrix
 from torch import device, dtype
@@ -28,7 +29,16 @@ class RandomHadamardFactory(HadamardFactory):
     :param seed: random seed used to transform weight randomization
     """
 
-    def _create_weight(self, size: int, dtype: dtype, device: device) -> Parameter:
-        data = random_hadamard_matrix(size, dtype, device, self.generator)
+    def _create_weight(
+        self,
+        size: int,
+        dtype: dtype,
+        device: device,
+        construct_device: device,
+    ) -> Parameter:
+        # construct on execution device, cache on offload device
+        data = random_hadamard_matrix(
+            size, torch.float32, construct_device, self.generator
+        )
         data = data.to(dtype=dtype, device=device)
         return Parameter(data, requires_grad=self.scheme.requires_grad)
diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py
index d8898ae4..c06d3e3f 100644
--- a/src/compressed_tensors/utils/helpers.py
+++ b/src/compressed_tensors/utils/helpers.py
@@ -373,11 +373,16 @@ class ParameterizedDefaultDict(dict):
 
     def __init__(self, default_factory: Callable[[Any], Any]):
         self.default_factory = default_factory
+        self._kwargs = {}
 
-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> Any:
         if isinstance(key, tuple):
-            value = self.default_factory(*key)
+            value = self.default_factory(*key, **self._kwargs)
         else:
-            value = self.default_factory(key)
+            value = self.default_factory(key, **self._kwargs)
         self[key] = value
         return value
+
+    def get(self, *args, **kwargs) -> Any:
+        with patch_attr(self, "_kwargs", kwargs):
+            return self[args]
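The helpers.py change is what threads construct_device through the cache: ParameterizedDefaultDict.get temporarily stashes keyword arguments in _kwargs so that __missing__ can forward them to the factory, while only the positional arguments form the cache key. One design note: this get shadows dict.get, so the usual d.get(key, default) lookup semantics no longer apply to these instances. A self-contained sketch of the mechanism, using unittest.mock.patch.object as a stand-in for the library's patch_attr context manager:

from typing import Any, Callable
from unittest.mock import patch  # stand-in for compressed_tensors' patch_attr

class ParameterizedDefaultDict(dict):
    """Dict whose missing values are built by calling a factory with the key."""

    def __init__(self, default_factory: Callable[..., Any]):
        self.default_factory = default_factory
        self._kwargs = {}

    def __missing__(self, key: Any) -> Any:
        # tuple keys are unpacked as positional args; _kwargs carries
        # extras (like construct_device) that stay out of the cache key
        if isinstance(key, tuple):
            value = self.default_factory(*key, **self._kwargs)
        else:
            value = self.default_factory(key, **self._kwargs)
        self[key] = value
        return value

    def get(self, *args, **kwargs) -> Any:
        # expose kwargs to __missing__ only for the duration of this lookup
        with patch.object(self, "_kwargs", kwargs):
            return self[args]

# toy factory: kwargs reach the factory, but only positional args key the cache
cache = ParameterizedDefaultDict(
    lambda size, name, construct_device=None: f"{name}:{size}@{construct_device}"
)
print(cache.get(64, "hadamard", construct_device="cuda:0"))  # hadamard:64@cuda:0
print((64, "hadamard") in cache)                             # True

Because construct_device is excluded from the key, two modules that share size, dtype, and offload device share one cached weight even if they execute on different devices.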