@@ -244,6 +244,32 @@ def load(
    callback_gen: Callable[[int, int], None] | None = None,
    progress: bool = False
):
+    """
+    Load model, regular manual split mode.
+
+    :param gpu_split:
+        List of VRAM allocations for weights and fixed buffers per GPU. Does not account for the size of the
+        cache, which must be allocated with reference to the model subsequently and whose split across GPUs
+        will depend on which devices end up receiving which attention layers.
+
+        If None, only the first GPU is used.
+
+    :param lazy:
+        Only set the device map according to the split, but don't actually load any of the modules. Modules
+        can subsequently be loaded and unloaded one by one for layer-streaming mode.
+
+    :param stats:
+        Legacy, unused
+
+    :param callback:
+        Callable function that triggers after each layer has loaded, for progress updates etc.
+
+    :param callback_gen:
+        Same as callback, but for use by async functions
+
+    :param progress:
+        If True, create a rich progress bar in the console while loading. Cannot be used with callbacks.
+    """

    if progress:
        progressbar = get_basic_progress()
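
For orientation (not part of the commit), a call into this manual-split mode might look like the sketch below. The `ExLlamaV2Config`/`ExLlamaV2` setup follows the library's usual pattern; the callback argument order `(layers_loaded, layers_total)` and the per-device GB units for `gpu_split` are assumptions inferred from the type hints and docstring above.

```python
# Minimal sketch, assuming the usual ExLlamaV2 setup; the callback argument
# order (layers_loaded, layers_total) is inferred from Callable[[int, int], None].
from exllamav2 import ExLlamaV2, ExLlamaV2Config

config = ExLlamaV2Config("/path/to/model")  # hypothetical model directory
model = ExLlamaV2(config)

def on_layer(loaded: int, total: int) -> None:
    print(f"Loaded {loaded}/{total} layers")

# Allocations assumed to be GB per device, covering weights and fixed buffers
# only; the cache is allocated separately against the loaded model.
model.load(gpu_split = [16, 24], callback = on_layer)
```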
@@ -270,7 +296,6 @@ def load_gen(
    callback: Callable[[int, int], None] | None = None,
    callback_gen: Callable[[int, int], None] | None = None
):
-
    with torch.inference_mode():

        stats_ = self.set_device_map(gpu_split or [99999])
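
`load_gen` appears to be the generator form of `load`, run under `torch.inference_mode()`. Assuming it yields control between module loads (an assumption, not stated in this commit), a caller could drive it manually:

```python
# Hedged sketch: assumes load_gen() yields between layer loads so the caller
# can interleave UI updates or async work; yielded values, if any, are ignored.
for _ in model.load_gen(gpu_split = [16, 24]):
    pass  # e.g. refresh a progress display between layers
```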
@@ -306,7 +331,34 @@ def load_tp(
    expect_cache_tokens: int = 0,
    expect_cache_base: type = None
):
+    """
+    Load model, tensor-parallel mode.
+
+    :param gpu_split:
+        List of VRAM allocations per GPU. The loader attempts to balance tensor splits to stay within these
+        allocations, accounting for an uneven distribution of attention heads and the expected size of the
+        cache.
+
+        If None, the loader attempts to use all available GPUs and creates a split based on the currently
+        available VRAM according to nvidia-smi etc.
+
+    :param callback:
+        Callable function that triggers after each layer has loaded, for progress updates etc.
+
+    :param callback_gen:
+        Same as callback, but for use by async functions
+
+    :param progress:
+        If True, create a rich progress bar in the console while loading. Cannot be used with callbacks.
+
+    :param expect_cache_tokens:
+        Expected size of the cache, in tokens (i.e. max_seq_len * max_batch_size, or just the cache size for
+        use with the dynamic generator), to inform the automatic tensor split. If not provided, the
+        configured max_seq_len for the model is assumed.
+
+    :param expect_cache_base:
+        Cache type to expect, e.g. ExLlamaV2Cache_Q6. Also informs the tensor split. If not provided, an
+        FP16 cache is assumed.
+    """
    if progress:
        progressbar = get_basic_progress()
        progressbar.start()
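
A tensor-parallel load using the new cache-sizing hints might look like the following sketch. `ExLlamaV2Cache_Q6` is named in the docstring above, though the import path used here is an assumption, as is the example token count.

```python
# Continuing the earlier sketch (model already constructed).
from exllamav2 import ExLlamaV2Cache_Q6

model.load_tp(
    gpu_split = None,                  # balance across all visible GPUs
    expect_cache_tokens = 32768,       # e.g. max_seq_len * max_batch_size
    expect_cache_base = ExLlamaV2Cache_Q6,
    progress = True,
)
```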
@@ -400,7 +452,31 @@ def load_autosplit(
    callback_gen: Callable[[int, int], None] | None = None,
    progress: bool = False
):
+    """
+    Load model, auto-split mode. This mode loads the model and builds the cache in parallel, using available
+    devices in turn and moving on to the next device whenever the previous one is full.
+
+    :param cache:
+        Cache constructed with lazy = True. Actual tensor allocation for the cache will happen while loading
+        the model.
+
+    :param reserve_vram:
+        Number of bytes to reserve on each device, either for all devices (as an int) or per device (as a
+        list).
+
+    :param last_id_only:
+        If True, the model will be loaded in a mode that can only output one set of logits (i.e. one token
+        position) per forward pass. This conserves memory if the model is only to be used for generating
+        text and not e.g. perplexity measurement.
+
+    :param callback:
+        Callable function that triggers after each layer has loaded, for progress updates etc.
+
+    :param callback_gen:
+        Same as callback, but for use by async functions
+
+    :param progress:
+        If True, create a rich progress bar in the console while loading. Cannot be used with callbacks.
+    """
    if progress:
        progressbar = get_basic_progress()
        progressbar.start()
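
A typical auto-split load, following the `lazy = True` requirement from the docstring; the cache class and the `reserve_vram` value are illustrative:

```python
# Continuing the earlier sketch (model already constructed).
from exllamav2 import ExLlamaV2Cache

cache = ExLlamaV2Cache(model, lazy = True)  # tensors allocated during load
model.load_autosplit(
    cache,
    reserve_vram = 512 * 1024**2,  # keep ~512 MB free on each device
    last_id_only = True,           # generation only; saves logits memory
    progress = True,
)
```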
@@ -569,6 +645,9 @@ def load_autosplit_gen(


def unload(self):
+    """
+    Unloads the model and frees all unmanaged resources.
+    """

    for module in self.modules:
        module.unload()
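
Usage is a single call; the explicit CUDA cache flush afterwards is an optional extra shown for illustration, not something this commit requires:

```python
import torch

model.unload()            # frees module weights and unmanaged resources
del cache
torch.cuda.empty_cache()  # optional: return freed VRAM to the allocator pool
```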