Skip to content

Commit 189f62d

Browse files
authored
add rounding to tqdm unit scale (tinygrad#9507)
fixed `AssertionError: ' 1.00/10.0 1000it/s]' != ' 1.00/10.0 1.00kit/s]'`
1 parent a5c971f commit 189f62d

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

test/unit/test_tqdm.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,41 @@ def test_unit_scale(self, mock_terminal_size, mock_stderr):
9292
self._compare_bars(tinytqdm_output, tqdm_output)
9393
if n > 3: break
9494

95+
@patch('sys.stderr', new_callable=StringIO)
96+
@patch('shutil.get_terminal_size')
97+
def test_unit_scale_exact(self, mock_terminal_size, mock_stderr):
98+
unit_scale = True
99+
ncols = 80
100+
mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)
101+
mock_stderr.truncate(0)
102+
103+
total = 10
104+
with patch('time.perf_counter', side_effect=[0]+list(range(100))): # one more 0 for the init call
105+
# compare bars at each iteration (only when tinytqdm bar has been updated)
106+
for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale, rate=1e9):
107+
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
108+
elapsed = n
109+
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
110+
self._compare_bars(tinytqdm_output, tqdm_output)
111+
if n > 5: break
112+
113+
total = 10
114+
k=0.001000001
115+
# regression test for
116+
# E AssertionError: ' 1.00/10.0 1000it/s]' != ' 1.00/10.0 1.00kit/s]'
117+
# E - 1.00/10.0 1000it/s]
118+
# E ? ^
119+
# E + 1.00/10.0 1.00kit/s]
120+
# E ? + ^
121+
with patch('time.perf_counter', side_effect=[0, *[i*k for i in range(100)]]): # one more 0 for the init call
122+
# compare bars at each iteration (only when tinytqdm bar has been updated)
123+
for n in tinytqdm(range(total), desc="Test", total=total, unit_scale=unit_scale, rate=1e9):
124+
tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
125+
elapsed = n*k
126+
tqdm_output = tqdm.format_meter(n=n, total=total, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=unit_scale)
127+
self._compare_bars(tinytqdm_output, tqdm_output)
128+
if n > 5: break
129+
95130
@patch('sys.stderr', new_callable=StringIO)
96131
@patch('shutil.get_terminal_size')
97132
def test_set_description(self, mock_terminal_size, mock_stderr):

tinygrad/helpers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,10 @@ def update(self, n:int=0, close:bool=False):
322322
self.n, self.i = self.n+n, self.i+1
323323
if self.disable or (not close and self.i % self.skip != 0): return
324324
prog, elapsed, ncols = self.n/self.t if self.t else 0, time.perf_counter()-self.st, shutil.get_terminal_size().columns
325-
if self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
325+
if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
326326
def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
327-
def SI(x): return (f"{x/1000**int(g:=math.log(x,1000)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
327+
def SI(x):
328+
return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
328329
prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
329330
est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
330331
it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"

0 commit comments

Comments
 (0)