Skip to content

Commit

Permalink
Hopefully fixing the byteswapping.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jul 30, 2024
1 parent 6556258 commit dd96068
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
6 changes: 3 additions & 3 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -928,14 +928,15 @@ impl PySafeSlice {
.call((storage_slice,), Some(&kwargs))?
.getattr(intern!(py, "view"))?
.call((), Some(&view_kwargs))?;
println!("Byte order {byteorder}");
if byteorder == "big" {
println!("Using torch byteswap");
let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(&version).map_err(SafetensorError::new_err)?;
if version >= Version::new(2, 1, 0) {
let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?;
// Clone is required otherwise storage is shared with previous slices,
// making n amount of byteswaps.
tensor = tensor.getattr(intern!(py, "clone"))?.call0()?;
tensor
.getattr(intern!(py, "untyped_storage"))?
.call0()?
Expand All @@ -946,7 +947,6 @@ impl PySafeSlice {
"PyTorch 2.1 or later is required for big-endian machine and bfloat16 support.",
));
} else {
println!("Using numpy byteswap");
let inplace_kwargs =
[(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);
let numpy = tensor
Expand Down
4 changes: 2 additions & 2 deletions bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def test_serialization(self):
)

data = torch.ones((2, 2), dtype=torch.bfloat16)
data[0, 0] = 2.25
out = save({"test": data})

self.assertEqual(
out,
b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"BF16","shape":[2,2],"data_offsets":[0,8]}} \x80?\x80?\x80?\x80?',
b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"BF16","shape":[2,2],"data_offsets":[0,8]}} \x10@\x80?\x80?\x80?',
)

def test_odd_dtype(self):
Expand Down
2 changes: 0 additions & 2 deletions bindings/python/tests/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,6 @@ def test_torch_slice(self):

tensor = slice_[:2]
self.assertEqual(list(tensor.shape), [2, 5])
if not torch.allclose(tensor, A[:2]):
print(f"{tensor} != {A[:2]}")
torch.testing.assert_close(tensor, A[:2])

tensor = slice_[:, :2]
Expand Down

0 comments on commit dd96068

Please sign in to comment.