Skip to content

Commit e4dad99

Browse files
leopfgeohotchenyuxyz
authored
nn.state docs cleanup (tinygrad#8332)
* doc cleanup * extension cleanup * manual definition * bring back accept_filename for gguf_load --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: chenyu <chenyu@fastmail.com>
1 parent 1ea4876 commit e4dad99

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

docs/nn.md

+7
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,12 @@
2929
::: tinygrad.nn.state.get_state_dict
3030
::: tinygrad.nn.state.get_parameters
3131
::: tinygrad.nn.state.load_state_dict
32+
::: tinygrad.nn.state.tar_extract
33+
options:
34+
show_signature: false
35+
separate_signature: false
3236
::: tinygrad.nn.state.torch_load
37+
options:
38+
show_signature: false
39+
separate_signature: false
3340
::: tinygrad.nn.state.gguf_load

tinygrad/nn/state.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> T: return func(Tensor(pathl
4343
@accept_filename
4444
def safe_load_metadata(t:Tensor) -> tuple[Tensor, int, dict[str, Any]]:
4545
"""
46-
Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
46+
Loads a .safetensor file, returning the source tensor, data start position, and metadata.
4747
"""
4848
data_start = int.from_bytes(t[0:8].data(), "little") + 8
4949
return t, data_start, json.loads(t[8:data_start].data().tobytes())
5050

5151
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
5252
"""
53-
Loads a .safetensor file from disk, returning the state_dict.
53+
Loads a .safetensor file, returning the `state_dict`.
5454
5555
```python
5656
state_dict = nn.state.safe_load("test.safetensor")
@@ -63,7 +63,7 @@ def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> dict[str, Tensor]:
6363

6464
def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any]]=None):
6565
"""
66-
Saves a state_dict to disk in a .safetensor file with optional metadata.
66+
Saves a `state_dict` to disk in a .safetensor file with optional metadata.
6767
6868
```python
6969
t = Tensor([1, 2, 3])
@@ -87,7 +87,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:Optional[dict[str, Any
8787

8888
def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
8989
"""
90-
Returns a state_dict of the object, with optional prefix.
90+
Returns a `state_dict` of the object, with optional prefix.
9191
9292
```python exec="true" source="above" session="tensor" result="python"
9393
class Net:
@@ -126,7 +126,7 @@ def __init__(self):
126126

127127
def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=True, consume=False, realize=True) -> None:
128128
"""
129-
Loads a state_dict into a model.
129+
Loads a `state_dict` into a model.
130130
131131
```python
132132
class Net:
@@ -162,7 +162,11 @@ def __init__(self):
162162
@accept_filename
163163
def tar_extract(t: Tensor) -> dict[str, Tensor]:
164164
"""
165-
Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
165+
```python
166+
tar_extract(fn: Tensor | str | Path) -> dict[str, Tensor]
167+
```
168+
169+
Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
166170
167171
```python
168172
tensors = nn.state.tar_extract(Tensor(pathlib.Path("archive.tar")))
@@ -176,7 +180,11 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
176180
@accept_filename
177181
def torch_load(t:Tensor) -> dict[str, Tensor]:
178182
"""
179-
Loads a torch .pth file from disk.
183+
```python
184+
torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
185+
```
186+
187+
Loads a torch .pth file, returning the `state_dict`.
180188
181189
```python
182190
state_dict = nn.state.torch_load("test.pth")
@@ -294,13 +302,14 @@ def q_to_uint8(t: Tensor, b: int) -> Tensor:
294302
@accept_filename
295303
def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
296304
"""
297-
Loads a gguf file from a tensor.
305+
Loads a .gguf file, returning the `kv_data` and `state_dict`.
298306
299307
```python
300-
fn = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
301-
gguf_tensor = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}").to(Device.DEFAULT)
302-
kv_data, state_dict = gguf_load(gguf_tensor)
308+
gguf_tensor = Tensor(pathlib.Path("Meta-Llama-3-8B-Instruct.Q4_0.gguf")).to(Device.DEFAULT)
309+
kv_data, state_dict = nn.state.gguf_load(gguf_tensor)
303310
```
311+
312+
NOTE: The provided tensor must be on a device that supports execution.
304313
"""
305314
reader, kv_data, state_dict = io.BufferedReader(TensorIO(tensor), 1_000_000), {}, {}
306315
def read_unpack(fmt: str, n: int): return struct.unpack(fmt, reader.read(n))[0]

0 commit comments

Comments
 (0)