From 655625827e4ae506453fb46aa6922295b3f9e3f9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 30 Jul 2024 09:16:24 +0200 Subject: [PATCH] Linter. --- bindings/python/tests/test_pt_comparison.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py index 6f674b8f..d0f052e5 100644 --- a/bindings/python/tests/test_pt_comparison.py +++ b/bindings/python/tests/test_pt_comparison.py @@ -55,7 +55,7 @@ def test_serialization(self): self.assertEqual( out, - b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"BF16","shape":[2,2],"data_offsets":[0,8]}} \x80?\x80?\x80?\x80?' + b'@\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"BF16","shape":[2,2],"data_offsets":[0,8]}} \x80?\x80?\x80?\x80?', ) def test_odd_dtype(self): @@ -87,10 +87,10 @@ def test_odd_dtype_fp8(self): save_file(data, local) reloaded = load_file(local) # note: PyTorch doesn't implement torch.equal for float8 so we just compare the single element - self.assertEqual(data["test1"].dtype, torch.float8_e4m3fn) - self.assertEqual(data["test1"].item(), -0.5) - self.assertEqual(data["test2"].dtype, torch.float8_e5m2) - self.assertEqual(data["test2"].item(), -0.5) + self.assertEqual(reloaded["test1"].dtype, torch.float8_e4m3fn) + self.assertEqual(reloaded["test1"].item(), -0.5) + self.assertEqual(reloaded["test2"].dtype, torch.float8_e5m2) + self.assertEqual(reloaded["test2"].item(), -0.5) def test_zero_sized(self): data = {