Skip to content

Commit

Permalink
Fixing format.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jul 26, 2024
1 parent 146fd4b commit e4811ac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
5 changes: 4 additions & 1 deletion bindings/python/py_src/safetensors/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ def save_model(
raise ValueError(msg)


def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu") -> Tuple[List[str], List[str]]:
def load_model(
model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu"
) -> Tuple[List[str], List[str]]:
"""
Loads a given filename onto a torch model.
This method exists specifically to avoid tensor sharing issues which are
Expand Down Expand Up @@ -340,6 +342,7 @@ def load(data: bytes) -> Dict[str, torch.Tensor]:
flat = deserialize(data)
return _view2torch(flat)


# torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
_float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
_float8_e5m2 = getattr(torch, "float8_e5m2", None)
Expand Down
2 changes: 1 addition & 1 deletion bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_odd_dtype(self):

def test_odd_dtype_fp8(self):
if torch.__version__ < "2.1":
return # torch.float8 requires 2.1
return # torch.float8 requires 2.1

data = {
"test1": torch.tensor([-0.5], dtype=torch.float8_e4m3fn),
Expand Down

0 comments on commit e4811ac

Please sign in to comment.