We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 741bbc9 commit 07d3676Copy full SHA for 07d3676
test/unit/test_disk_tensor.py
@@ -10,7 +10,7 @@ def compare_weights_both(url):
10
import torch
11
fn = fetch(url)
12
tg_weights = get_state_dict(torch_load(fn))
13
- torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=True), tensor_type=torch.Tensor)
+ torch_weights = get_state_dict(torch.load(fn, map_location=torch.device('cpu'), weights_only=False), tensor_type=torch.Tensor)
14
assert list(tg_weights.keys()) == list(torch_weights.keys())
15
for k in tg_weights:
16
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
0 commit comments