diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py index 48532ea5..a8a47d76 100644 --- a/bindings/python/py_src/safetensors/torch.py +++ b/bindings/python/py_src/safetensors/torch.py @@ -256,6 +256,7 @@ def save_file( tensors: Dict[str, torch.Tensor], filename: Union[str, os.PathLike], metadata: Optional[Dict[str, str]] = None, + force_contiguous: Optional[bool] = True, ): """ Saves a dictionary of tensors into raw bytes in safetensors format. @@ -269,6 +270,11 @@ def save_file( Optional text only metadata you might want to save in your header. For instance it can be useful to specify more about the underlying tensors. This is purely informative and does not affect tensor loading. + force_contiguous (`boolean`, *optional*, defaults to True): + Forcing the state_dict to be saved as contiguous tensors. + This has no effect on the correctness of the model, but it + could potentially change performance if the layout of the tensor + was chosen specifically for that reason. Returns: `None` @@ -283,6 +289,8 @@ def save_file( save_file(tensors, "model.safetensors") ``` """ + if force_contiguous: + tensors = {k: v.contiguous() for k, v in tensors.items()} serialize_file(_flatten(tensors), filename, metadata=metadata)