Skip to content

Commit

Permalink
Update 2 locations.
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil committed Jul 26, 2024
1 parent e4811ac commit 5eaf8a1
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -930,13 +930,29 @@ impl PySafeSlice {
if byteorder == "big" {
let inplace_kwargs =
[(intern!(py, "inplace"), false.into_py(py))].into_py_dict_bound(py);

let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?;

let version: String = torch.getattr(intern!(py, "__version__"))?.extract()?;
let version =
Version::from_string(&version).map_err(SafetensorError::new_err)?;
if version >= Version::new(2, 1, 0) {
tensor
.getattr(intern!(py, "untyped_storage"))?
.call0()?
.getattr(intern!(py, "byteswap"))?
.call1((dtype,))?;
} else if self.info.dtype == Dtype::BF16 {
return Err(SafetensorError::new_err(
"PyTorch 2.1 or later is required for big-endian machine",
));
} else {
let numpy = tensor
.getattr(intern!(py, "numpy"))?
.call0()?
.getattr("byteswap")?
.call((), Some(&inplace_kwargs))?;
tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
}
}
tensor = tensor
.getattr(intern!(py, "reshape"))?
Expand Down

0 comments on commit 5eaf8a1

Please sign in to comment.