From 4967d4824ffa3e484f340e150c956c4a8c6549e9 Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Mon, 22 Apr 2024 13:33:17 +0200
Subject: [PATCH 1/2] Add device_map support

---
 docs/source/conf.py               |  4 ++--
 inseq/models/attribution_model.py | 11 ++++++++++-
 inseq/models/huggingface_model.py |  3 +++
 inseq/utils/__init__.py           |  2 ++
 inseq/utils/import_utils.py       |  5 +++++
 pyproject.toml                    |  2 +-
 6 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index 28eed915..d2805c1e 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -25,9 +25,9 @@
 author = "The Inseq Team"
 
 # The short X.Y version
-version = "0.6"
+version = "0.7"
 # The full version, including alpha/beta/rc tags
-release = "0.6.0"
+release = "0.7.0.dev0"
 
 # Prefix link to point to master, comment this during version release and uncomment below line
diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py
index b96db5c2..49c55d2c 100644
--- a/inseq/models/attribution_model.py
+++ b/inseq/models/attribution_model.py
@@ -22,6 +22,7 @@
     format_input_texts,
     get_adjusted_alignments,
     get_default_device,
+    is_accelerate_available,
     isnotebook,
     pretty_tensor,
 )
@@ -219,6 +220,7 @@ def __init__(self, **kwargs) -> None:
         self.pad_token: Optional[str] = None
         self.embed_scale: Optional[float] = None
         self._device: Optional[str] = None
+        self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None
         self.attribution_method: Optional[FeatureAttribution] = None
         self.is_hooked: bool = False
         self._default_attributed_fn_id: str = "probability"
@@ -234,7 +236,14 @@ def device(self, new_device: str) -> None:
         check_device(new_device)
         self._device = new_device
         if self.model:
-            self.model.to(self._device)
+            if self.device_map is None:
+                self.model.to(self._device)
+            elif is_accelerate_available():
+                from accelerate import dispatch_model
+
+                self.model = dispatch_model(self.model, device_map=self.device_map)
+            else:
+                raise ImportError("Accelerate is not available, but device_map is set.")
 
     def setup(self, device: Optional[str] = None, attribution_method: Optional[str] = None, **kwargs) -> None:
         """Move the model to device and in eval mode."""
diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py
index c7416a58..1b068316 100644
--- a/inseq/models/huggingface_model.py
+++ b/inseq/models/huggingface_model.py
@@ -127,6 +127,9 @@ def __init__(
         self.embed_scale = 1.0
         self.encoder_int_embeds = None
         self.decoder_int_embeds = None
+        self.device_map = None
+        if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None:
+            self.device_map = self.model.hf_device_map
         self.is_encoder_decoder = self.model.config.is_encoder_decoder
         self.configure_embeddings_scale()
         self.setup(device, attribution_method, **kwargs)
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index f632ba32..9eb39aba 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -10,6 +10,7 @@
 )
 from .hooks import StackFrame, get_post_variable_assignment_hook
 from .import_utils import (
+    is_accelerate_available,
     is_captum_available,
     is_datasets_available,
     is_ipywidgets_available,
@@ -130,4 +131,5 @@
     "validate_indices",
     "pad_with_nan",
     "recursive_get_submodule",
+    "is_accelerate_available",
 ]
diff --git a/inseq/utils/import_utils.py b/inseq/utils/import_utils.py
index 2a1ccc2d..e8ae455e 100644
--- a/inseq/utils/import_utils.py
+++ b/inseq/utils/import_utils.py
@@ -8,6 +8,7 @@
 _captum_available = find_spec("captum") is not None
 _joblib_available = find_spec("joblib") is not None
 _nltk_available = find_spec("nltk") is not None
+_accelerate_available = find_spec("accelerate") is not None
 
 
 def is_ipywidgets_available():
@@ -40,3 +41,7 @@ def is_joblib_available():
 
 def is_nltk_available():
     return _nltk_available
+
+
+def is_accelerate_available():
+    return _accelerate_available
diff --git a/pyproject.toml b/pyproject.toml
index 8b0bdd05..32592ee8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "inseq"
-version = "0.6.0"
+version = "0.7.0.dev0"
 description = "Interpretability for Sequence Generation Models 🔍"
 readme = "README.md"
 requires-python = ">=3.9"

From 4b2e727c7f03aebca34a9593faadb4eaf4928b04 Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Tue, 23 Apr 2024 22:45:09 +0200
Subject: [PATCH 2/2] Fix device setter in HF model

---
 inseq/models/attribution_model.py | 10 +---------
 inseq/models/huggingface_model.py |  9 ++++++---
 2 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py
index 49c55d2c..2f259a4c 100644
--- a/inseq/models/attribution_model.py
+++ b/inseq/models/attribution_model.py
@@ -22,7 +22,6 @@
     format_input_texts,
     get_adjusted_alignments,
     get_default_device,
-    is_accelerate_available,
     isnotebook,
     pretty_tensor,
 )
@@ -236,14 +235,7 @@ def device(self, new_device: str) -> None:
         check_device(new_device)
         self._device = new_device
         if self.model:
-            if self.device_map is None:
-                self.model.to(self._device)
-            elif is_accelerate_available():
-                from accelerate import dispatch_model
-
-                self.model = dispatch_model(self.model, device_map=self.device_map)
-            else:
-                raise ImportError("Accelerate is not available, but device_map is set.")
+            self.model.to(self._device)
 
     def setup(self, device: Optional[str] = None, attribution_method: Optional[str] = None, **kwargs) -> None:
         """Move the model to device and in eval mode."""
diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py
index 1b068316..d6cc3f6b 100644
--- a/inseq/models/huggingface_model.py
+++ b/inseq/models/huggingface_model.py
@@ -165,16 +165,19 @@ def device(self, new_device: str) -> None:
         is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
         is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
         is_quantized = is_loaded_in_8bit or is_loaded_in_4bit
+        has_device_map = self.device_map is not None
         # Enable compatibility with 8bit models
         if self.model:
-            if not is_quantized:
-                self.model.to(self._device)
-            else:
+            if is_quantized:
                 mode = "8bit" if is_loaded_in_8bit else "4bit"
                 logger.warning(
                     f"The model is loaded in {mode} mode. The device cannot be changed after loading the model."
                 )
+            elif has_device_map:
+                logger.warning("The model is loaded with a device map. The device cannot be changed after loading.")
+            else:
+                self.model.to(self._device)
 
     @abstractmethod
     def configure_embeddings_scale(self) -> None:
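
Usage sketch (illustrative note, not part of the patches above): with these changes, a model loaded through Accelerate's device_map keeps its device placement, and Inseq records the resulting hf_device_map in model.device_map instead of calling model.to(...). The snippet below assumes accelerate is installed and that model_kwargs passed to inseq.load_model are forwarded to from_pretrained; the model identifier and device_map="auto" value are examples only.

import inseq

# Load a model dispatched across available devices via accelerate's device_map.
# The hf_device_map set by from_pretrained is picked up in HuggingfaceModel.__init__
# and stored as model.device_map; the device setter then warns instead of moving the model.
model = inseq.load_model(
    "gpt2",  # illustrative model identifier
    "saliency",  # any attribution method works here
    model_kwargs={"device_map": "auto"},  # assumption: forwarded to from_pretrained; requires accelerate
)
out = model.attribute("Hello world, this is a test of device_map support.")
out.show()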