-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfileio_utils.py
24 lines (20 loc) · 919 Bytes
/
fileio_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import math
import os

import numpy as np
import torch
def save_int(t: torch.Tensor, scaling_factor: int, path):
    """Quantize a tensor to int32 fixed point and write it as raw binary.

    The tensor is multiplied by ``scaling_factor``. Floating-point results
    are rounded with ``torch.round`` (round half to even). The values are
    then cast to int32 and written with ``numpy.ndarray.tofile``. The file
    contains only the flattened values in native byte order. There is no
    shape or dtype header, so a reader must know both.

    Args:
        t: Tensor to save; may be on any device and may require grad.
        scaling_factor: Multiplier applied before rounding.
        path: Destination file (``str``, ``bytes`` or ``os.PathLike``);
            must end with ``.bin``.

    Raises:
        ValueError: If ``path`` does not end with ``.bin``, or if any scaled
            value is NaN/inf or does not fit in int32. Casting such values
            would silently write garbage.
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    scaled = t.detach() * scaling_factor
    if scaled.is_floating_point():
        scaled = torch.round(scaled)
    if scaled.numel() > 0:
        # .item() yields exact Python ints/floats, and Python compares int
        # with float exactly, so the range check itself cannot overflow.
        lo, hi = scaled.min().item(), scaled.max().item()
        info = torch.iinfo(torch.int32)
        if (not (math.isfinite(lo) and math.isfinite(hi))
                or lo < info.min or hi > info.max):
            raise ValueError(
                f'Scaled values must be finite and fit in int32, '
                f'got range [{lo}, {hi}]'
            )
    scaled.to(torch.int32).cpu().numpy().tofile(path)
def save_long(t: torch.Tensor, scaling_factor: int, path):
    """Quantize a tensor to int64 fixed point and write it as raw binary.

    The tensor is multiplied by ``scaling_factor``. Floating-point results
    are rounded with ``torch.round`` (round half to even). The values are
    then cast to int64 and written with ``numpy.ndarray.tofile``. The file
    contains only the flattened values in native byte order. There is no
    shape or dtype header, so a reader must know both.

    Args:
        t: Tensor to save; may be on any device and may require grad.
        scaling_factor: Multiplier applied before rounding.
        path: Destination file (``str``, ``bytes`` or ``os.PathLike``);
            must end with ``.bin``.

    Raises:
        ValueError: If ``path`` does not end with ``.bin``, or if any scaled
            value is NaN/inf or does not fit in int64. Casting such values
            would silently write garbage.
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    scaled = t.detach() * scaling_factor
    if scaled.is_floating_point():
        scaled = torch.round(scaled)
    if scaled.numel() > 0:
        # .item() yields exact Python ints/floats, and Python compares int
        # with float exactly (2**63 as a float is correctly seen as > max).
        lo, hi = scaled.min().item(), scaled.max().item()
        info = torch.iinfo(torch.int64)
        if (not (math.isfinite(lo) and math.isfinite(hi))
                or lo < info.min or hi > info.max):
            raise ValueError(
                f'Scaled values must be finite and fit in int64, '
                f'got range [{lo}, {hi}]'
            )
    scaled.to(torch.int64).cpu().numpy().tofile(path)
def load_int(path, device=0):
    """Load a raw int32 ``.bin`` file into a 1-D tensor.

    The file is read with ``numpy.fromfile`` as native-byte-order int32, so
    the result is always flat. Any original shape must be restored by the
    caller. Values are returned as stored (still multiplied by whatever
    scaling factor was used when they were written).

    Args:
        path: Source file (``str``, ``bytes`` or ``os.PathLike``); must end
            with ``.bin``.
        device: Passed to ``Tensor.to``. The default ``0`` is an integer
            device ordinal, which torch treats as a CUDA device index; pass
            ``'cpu'`` on machines without CUDA.

    Returns:
        A 1-D ``torch.int32`` tensor on ``device``.

    Raises:
        ValueError: If ``path`` does not end with ``.bin`` or the file size is
            not a whole number of int32 items (truncated or wrong-dtype file).
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    itemsize = np.dtype(np.int32).itemsize
    size = os.path.getsize(path)
    if size % itemsize:
        # np.fromfile would silently drop the trailing partial item.
        raise ValueError(
            f'File size {size} is not a multiple of {itemsize} bytes'
        )
    return torch.from_numpy(np.fromfile(path, dtype=np.int32)).to(device)
def load_long(path, device=0):
    """Load a raw int64 ``.bin`` file into a 1-D tensor.

    The file is read with ``numpy.fromfile`` as native-byte-order int64, so
    the result is always flat. Any original shape must be restored by the
    caller. Values are returned as stored (still multiplied by whatever
    scaling factor was used when they were written).

    Args:
        path: Source file (``str``, ``bytes`` or ``os.PathLike``); must end
            with ``.bin``.
        device: Passed to ``Tensor.to``. The default ``0`` is an integer
            device ordinal, which torch treats as a CUDA device index; pass
            ``'cpu'`` on machines without CUDA.

    Returns:
        A 1-D ``torch.int64`` tensor on ``device``.

    Raises:
        ValueError: If ``path`` does not end with ``.bin`` or the file size is
            not a whole number of int64 items (truncated or wrong-dtype file).
    """
    if not os.fsdecode(path).endswith('.bin'):
        raise ValueError('Path must end with .bin')
    itemsize = np.dtype(np.int64).itemsize
    size = os.path.getsize(path)
    if size % itemsize:
        # np.fromfile would silently drop the trailing partial item.
        raise ValueError(
            f'File size {size} is not a multiple of {itemsize} bytes'
        )
    return torch.from_numpy(np.fromfile(path, dtype=np.int64)).to(device)