def concat_vectors_and_pad(vec_list, max_):
+    """
+    Concatenates a list of input vectors and pads them to match a specified maximum
+    length.
+
+    This function takes a list of input vectors, concatenates them along a specified
+    dimension (dim=0), and then pads the concatenated vector to achieve a specified
+    maximum length. The padding is done with zeros.
+
+    Args:
+        vec_list (list of torch.Tensor): List of input vectors to concatenate and pad.
+        max_ (int): The maximum length of the concatenated and padded vector.
+
+    Returns:
+        torch.Tensor: The concatenated and padded vector.
+
+    Raises:
+        AssertionError: If the length of 'vec_list' is not greater than 0, or if it
+            exceeds 'max_', or if 'max_' is not greater than 0.
+
+    Example:
+        >>> input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]
+        >>> max_length = 5
+        >>> concatenated_padded = concat_vectors_and_pad(input_tensors, max_length)
+        >>> print(concatenated_padded)
+        tensor([1, 2, 3, 4, 5])
+    """
    assert len(vec_list) > 0
    assert len(vec_list) <= max_
    assert max_ > 0
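For review context, this hunk only shows the new docstring and the existing assertions; the rest of the body is not part of the diff. A minimal sketch of the behaviour the docstring describes (concatenation along dim=0 followed by right-padding with zeros) could look like the following. The name `concat_vectors_and_pad_sketch` and the use of `torch.nn.functional.pad` are illustrative assumptions, not lightwood's actual implementation:

```python
import torch
import torch.nn.functional as F

def concat_vectors_and_pad_sketch(vec_list, max_):
    # Mirror the assertions shown in the diff.
    assert len(vec_list) > 0
    assert len(vec_list) <= max_
    assert max_ > 0

    # Concatenate all 1-D tensors along dim=0, then zero-pad on the right up to max_.
    cat_vec = torch.cat(vec_list, dim=0)
    pad_size = max_ - cat_vec.size(0)
    return F.pad(cat_vec, (0, pad_size), value=0) if pad_size > 0 else cat_vec

# Matches the docstring example: no padding is needed because the total length is already 5.
print(concat_vectors_and_pad_sketch([torch.tensor([1, 2]), torch.tensor([3, 4, 5])], 5))
# tensor([1, 2, 3, 4, 5])
```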
@@ -27,10 +53,29 @@ class LightwoodAutocast:
    """
    Equivalent to torch.cuda.amp.autocast, but checks device compute capability
    to activate the feature only when the GPU has tensor cores to leverage AMP.
+
+    **Attributes:**
+
+    * `active` (bool): Whether AMP is currently active. This is a class-level attribute.
+
+    **Usage:**
+
+    ```python
+    >>> import lightwood.helpers.torch as lt
+    >>> with lt.LightwoodAutocast():
+    ...     # This code will be executed in AMP mode (when the GPU supports it).
+    ...     pass
+    ```
    """
    active = False

    def __init__(self, enabled=True):
+        """
+        Initializes the context manager for Automatic Mixed Precision (AMP) functionality.
+
+        Args:
+            enabled (bool, optional): Whether to enable AMP. Defaults to True.
+        """
        self.major = 0  # GPU major version
        torch_version = [int(i) for i in torch.__version__.split('.')[:-1]]
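The class docstring states that AMP is activated only when the GPU has tensor cores, and `__init__` begins by recording the GPU major version. The diff does not show how that decision is reached; a hedged sketch of that kind of capability check (the helper name and threshold are assumptions, not lightwood's code) could be:

```python
import torch

def _has_tensor_cores_sketch() -> bool:
    # Hypothetical helper, not lightwood's actual check: tensor cores are
    # available from CUDA compute capability 7.0 (Volta) onwards.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 7
```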
@@ -50,12 +95,18 @@ def __init__(self, enabled=True):
        LightwoodAutocast.active = self._enabled

    def __enter__(self):
+        """
+        Enters the context manager and enables AMP if it is not already enabled.
+        """
        if self._enabled:
            self.prev = torch.is_autocast_enabled()
            torch.set_autocast_enabled(self._enabled)
            torch.autocast_increment_nesting()

    def __exit__(self, *args):
+        """
+        Exits the context manager and disables AMP.
+        """
        if self._enabled:
            # Drop the cache when we exit to a nesting level that's outside any instance of autocast
            if torch.autocast_decrement_nesting() == 0:
@@ -64,6 +115,9 @@ def __exit__(self, *args):
        return False

    def __call__(self, func):
+        """
+        Returns a decorated function that enables AMP when it is called.
+        """
        @functools.wraps(func)
        def decorate_autocast(*args, **kwargs):
            with self:
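The `__call__` docstring above implies the class also works as a function decorator. A small usage sketch combining both entry points (the import path follows the class docstring; the rest is illustrative, not taken from the lightwood test suite):

```python
import torch
from lightwood.helpers.torch import LightwoodAutocast

# Context-manager form: AMP is only activated when the GPU supports it.
with LightwoodAutocast():
    y = torch.matmul(torch.randn(8, 8), torch.randn(8, 8))

# Decorator form via __call__: the wrapped function body runs inside the same context.
@LightwoodAutocast()
def forward(x, w):
    return torch.matmul(x, w)

out = forward(torch.randn(8, 8), torch.randn(8, 8))
print(LightwoodAutocast.active)  # class-level flag set when AMP is enabled
```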