Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for device_map use #264

Merged
merged 2 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
author = "The Inseq Team"

# The short X.Y version
version = "0.6"
version = "0.7"
# The full version, including alpha/beta/rc tags
release = "0.6.0"
release = "0.7.0.dev0"


# Prefix link to point to master, comment this during version release and uncomment below line
Expand Down
1 change: 1 addition & 0 deletions inseq/models/attribution_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(self, **kwargs) -> None:
self.pad_token: Optional[str] = None
self.embed_scale: Optional[float] = None
self._device: Optional[str] = None
self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None
self.attribution_method: Optional[FeatureAttribution] = None
self.is_hooked: bool = False
self._default_attributed_fn_id: str = "probability"
Expand Down
12 changes: 9 additions & 3 deletions inseq/models/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __init__(
self.embed_scale = 1.0
self.encoder_int_embeds = None
self.decoder_int_embeds = None
self.device_map = None
if hasattr(self.model, "hf_device_map") and self.model.hf_device_map is not None:
self.device_map = self.model.hf_device_map
self.is_encoder_decoder = self.model.config.is_encoder_decoder
self.configure_embeddings_scale()
self.setup(device, attribution_method, **kwargs)
Expand Down Expand Up @@ -162,16 +165,19 @@ def device(self, new_device: str) -> None:
is_loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
is_loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
is_quantized = is_loaded_in_8bit or is_loaded_in_4bit
has_device_map = self.device_map is not None

# Enable compatibility with 8bit models
if self.model:
if not is_quantized:
self.model.to(self._device)
else:
if is_quantized:
mode = "8bit" if is_loaded_in_8bit else "4bit"
logger.warning(
f"The model is loaded in {mode} mode. The device cannot be changed after loading the model."
)
elif has_device_map:
logger.warning("The model is loaded with a device map. The device cannot be changed after loading.")
else:
self.model.to(self._device)

@abstractmethod
def configure_embeddings_scale(self) -> None:
Expand Down
2 changes: 2 additions & 0 deletions inseq/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from .hooks import StackFrame, get_post_variable_assignment_hook
from .import_utils import (
is_accelerate_available,
is_captum_available,
is_datasets_available,
is_ipywidgets_available,
Expand Down Expand Up @@ -130,4 +131,5 @@
"validate_indices",
"pad_with_nan",
"recursive_get_submodule",
"is_accelerate_available",
]
5 changes: 5 additions & 0 deletions inseq/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Optional-dependency flags, evaluated once when this module is imported.
# Each flag is True when `find_spec` returns a non-None result for the package name.
# NOTE(review): `find_spec` is presumably `importlib.util.find_spec`; its import is outside this view — confirm.
_captum_available = find_spec("captum") is not None
_joblib_available = find_spec("joblib") is not None
_nltk_available = find_spec("nltk") is not None
_accelerate_available = find_spec("accelerate") is not None


def is_ipywidgets_available():
Expand Down Expand Up @@ -40,3 +41,7 @@ def is_joblib_available():

def is_nltk_available():
    """Return the module-level flag recording whether ``nltk`` was located.

    The flag is computed once at module import, so this call does no lookup itself.
    """
    return _nltk_available


def is_accelerate_available():
    """Return the module-level flag recording whether ``accelerate`` was located.

    The flag is computed once at module import, so this call does no lookup itself.
    """
    return _accelerate_available
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "inseq"
version = "0.6.0"
version = "0.7.0.dev0"
description = "Interpretability for Sequence Generation Models 🔍"
readme = "README.md"
requires-python = ">=3.9"
Expand Down
Loading