
Commit d58d5ab

Add some docstrings
1 parent 555c360 commit d58d5ab

1 file changed: +80 −1 lines changed

exllamav2/model.py

@@ -244,6 +244,32 @@ def load(
         callback_gen: Callable[[int, int], None] | None = None,
         progress: bool = False
     ):
+        """
+        Load model, regular manual split mode.
+
+        :param gpu_split:
+            List of VRAM allocations for weights and fixed buffers per GPU. Does not account for the size of the cache
+            which must be allocated with reference to the model subsequently and whose split across GPUs will depend
+            on which devices end up receiving which attention layers.
+
+            If None, only the first GPU is used.
+
+        :param lazy:
+            Only set the device map according to the split, but don't actually load any of the modules. Modules can
+            subsequently be loaded and unloaded one by one for layer-streaming mode.
+
+        :param stats:
+            Legacy, unused
+
+        :param callback:
+            Callable function that triggers after each layer has loaded, for progress update etc.
+
+        :param callback_gen:
+            Same as callback, but for use by async functions
+
+        :param progress:
+            If True, create a rich progress bar in the console while loading. Cannot be used with callbacks
+        """
 
         if progress:
             progressbar = get_basic_progress()
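
For context, a minimal usage sketch of the manual-split path documented above. The model path and split values are hypothetical, and the sketch assumes the convention from the exllamav2 examples that gpu_split values are given in GB and that the callback receives (modules loaded so far, total modules):

# Minimal sketch, not taken from the repository: manual-split loading.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config("/models/my-model-exl2")   # hypothetical model directory
model = ExLlamaV2(config)

def on_module(loaded: int, total: int):
    # Assumed meaning of the two ints: modules loaded so far / total modules
    print(f"Loaded {loaded}/{total} modules")

# Roughly 20 GB on GPU 0 and 24 GB on GPU 1 for weights and fixed buffers;
# per the docstring, the cache is not part of this budget and is allocated afterwards.
model.load(gpu_split = [20, 24], callback = on_module)

# The cache follows the layer placement decided during load()
cache = ExLlamaV2Cache(model)
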
@@ -270,7 +296,6 @@ def load_gen(
         callback: Callable[[int, int], None] | None = None,
         callback_gen: Callable[[int, int], None] | None = None
     ):
-
         with torch.inference_mode():
 
             stats_ = self.set_device_map(gpu_split or [99999])
@@ -306,7 +331,34 @@ def load_tp(
         expect_cache_tokens: int = 0,
         expect_cache_base: type = None
     ):
+        """
+        Load model, tensor-parallel mode.
+
+        :param gpu_split:
+            List of VRAM allocations per GPU. The loader attempts to balance tensor splits to stay within these
+            allocations, accounting for an uneven distribution of attention heads and the expected size of the cache.
+
+            If None, the loader attempts to use all available GPUs and creates a split based on the currently available
+            VRAM according to nvidia-smi etc.
+
+        :param callback:
+            Callable function that triggers after each layer has loaded, for progress update etc.
+
+        :param callback_gen:
+            Same as callback, but for use by async functions
 
+        :param progress:
+            If True, create a rich progress bar in the console while loading. Cannot be used with callbacks
+
+        :param expect_cache_tokens:
+            Expected size of the cache, in tokens (i.e. max_seq_len * max_batch_size, or just the cache size for use
+            with the dynamic generator) to inform the automatic tensor split. If not provided, the configured
+            max_seq_len for the model is assumed.
+
+        :param expect_cache_base:
+            Cache type to expect, e.g. ExLlamaV2Cache_Q6. Also informs the tensor split. If not provided, FP16 cache
+            is assumed.
+        """
         if progress:
             progressbar = get_basic_progress()
             progressbar.start()
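
A sketch of the tensor-parallel path under the same assumptions (paths and sizes are hypothetical; ExLlamaV2Cache_Q6 is used only because the docstring names it as an example cache type):

# Minimal sketch, not taken from the repository: tensor-parallel loading.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q6

config = ExLlamaV2Config("/models/my-model-exl2")   # hypothetical model directory
model = ExLlamaV2(config)

# gpu_split = None: use all visible GPUs, split by currently available VRAM.
# The expected cache (32768 tokens, Q6) informs how the tensor split is balanced.
model.load_tp(
    gpu_split = None,
    expect_cache_tokens = 32768,
    expect_cache_base = ExLlamaV2Cache_Q6,
    progress = True
)
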
@@ -400,7 +452,31 @@ def load_autosplit(
         callback_gen: Callable[[int, int], None] | None = None,
         progress: bool = False
     ):
+        """
+        Load model, auto-split mode. This mode loads the model and builds the cache in parallel, using available
+        devices in turn and moving on to the next device whenever the previous one is full.
+
+        :param cache:
+            Cache constructed with lazy = True. Actual tensor allocation for the cache will happen while loading the
+            model.
+
+        :param reserve_vram:
+            Number of bytes to reserve on each device, either for all devices (as an int) or per-device (as a list).
 
+        :param last_id_only:
+            If True, model will be loaded in a mode that can only output one set of logits (i.e. one token
+            position) per forward pass. This conserves memory if the model is only to be used for generating text and
+            not e.g. perplexity measurement.
+
+        :param callback:
+            Callable function that triggers after each layer has loaded, for progress update etc.
+
+        :param callback_gen:
+            Same as callback, but for use by async functions
+
+        :param progress:
+            If True, create a rich progress bar in the console while loading. Cannot be used with callbacks
+        """
         if progress:
             progressbar = get_basic_progress()
             progressbar.start()
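
And a sketch of the auto-split path in the order the docstring describes: construct a lazy cache first, then let the loader allocate model and cache together while filling devices in turn (values again hypothetical):

# Minimal sketch, not taken from the repository: auto-split loading.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache

config = ExLlamaV2Config("/models/my-model-exl2")   # hypothetical model directory
model = ExLlamaV2(config)

# lazy = True: no cache tensors are allocated yet; they are created per device
# while the model loads.
cache = ExLlamaV2Cache(model, lazy = True)

# Keep roughly 1 GiB free on every device (an int applies to all devices; a list
# would give a per-device reservation), and show a progress bar instead of callbacks.
model.load_autosplit(cache, reserve_vram = 1024**3, progress = True)
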
@@ -569,6 +645,9 @@ def load_autosplit_gen(
 
 
     def unload(self):
+        """
+        Unloads the model and frees all unmanaged resources.
+        """
 
         for module in self.modules:
             module.unload()
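
Tearing the model down again is a single call, continuing the hypothetical model object from the sketches above:

model.unload()   # releases weights and other unmanaged resources on all devices
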
