Commit 3ae66e5

least_upper_float is at least default_float (tinygrad#9303)
1 parent 3210b65

* least_upper_float is at least default_float, en route to div rounding mode. The dtype of true int division changes from int32 to default_float, which matches torch.
* fix bert acc
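As a quick illustration of the behavior change, here is a minimal sketch (assuming a tinygrad checkout at this commit; the identifiers are real tinygrad API):

```python
from tinygrad import Tensor, dtypes

# True division of two int tensors now follows dtypes.default_float,
# as the new test_int_div_int below asserts.
dtypes.default_float = dtypes.float16
print(Tensor([1]).div(Tensor([2])).dtype)  # dtypes.float16
```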

3 files changed: +17 -6 lines changed

extra/models/bert.py (+1 -1)

```diff
@@ -72,7 +72,7 @@ def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, mas
     next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
 
     # TODO: is it okay that next_sentence_loss is half here?
-    return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
+    return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
 
   def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
     os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Mute tf flag info
```
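Why the extra .float(): with least_upper_float now tracking default_float, an int/int division lands in half precision whenever default_float is half, which is presumably the accuracy issue behind "fix bert acc". A minimal sketch of the dtype outcome, assuming default_float is half as in a mixed-precision BERT run:

```python
from tinygrad import Tensor, dtypes

dtypes.default_float = dtypes.half                   # assumption: mixed-precision run
correct, valid = Tensor([123]), Tensor([456])        # int32 counts
print((correct.sum() / valid.sum()).dtype)           # half under the new promotion rule
print((correct.sum().float() / valid.sum()).dtype)   # float32: the cast keeps the ratio in full precision
```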

test/test_dtype.py (+15 -4)

```diff
@@ -633,16 +633,22 @@ def test_dtype_promo(self):
     assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
     assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
 
-  @given(strat.sampled_from(dtype_floats))
-  def test_float_to_float(self, dt):
-    assert least_upper_float(dt) == dt
-
 class TestAutoCastType(unittest.TestCase):
   def setUp(self):
     self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float
   def tearDown(self):
     dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float
 
+  @given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
+  def test_least_upper_float_input_is_float(self, input_dtype, default_float):
+    dtypes.default_float = default_float
+    self.assertEqual(least_upper_float(input_dtype), input_dtype)
+
+  @given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
+  def test_least_upper_float_input_is_int(self, input_dtype, default_float):
+    dtypes.default_float = default_float
+    self.assertEqual(least_upper_float(input_dtype), default_float)
+
   @given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
   def test_int_to_float_unary_func(self, dtype):
     for func in [
@@ -667,6 +673,11 @@ def test_broadcast_scalar(self, dt):
     assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
     assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
 
+  @given(strat.sampled_from(dtype_floats))
+  def test_int_div_int(self, default_float):
+    dtypes.default_float = default_float
+    self.assertEqual(Tensor([1]).div(Tensor([2])).dtype, default_float)
+
   def test_sum(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
     assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
```
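The new cases are hypothesis property tests sampled over dtype_floats and dtype_ints; a fixed-value spot check of the same contract (a sketch, not part of the commit):

```python
from tinygrad import dtypes
from tinygrad.dtype import least_upper_float

dtypes.default_float = dtypes.float16
assert least_upper_float(dtypes.float64) == dtypes.float64  # float inputs pass through
assert least_upper_float(dtypes.int64) == dtypes.float16    # int inputs promote to default_float
dtypes.default_float = dtypes.float32                       # restore, as tearDown does
```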

tinygrad/dtype.py (+1 -1)

```diff
@@ -169,7 +169,7 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
 @functools.lru_cache(None)
 def least_upper_dtype(*ds:DType) -> DType:
   return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
-def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
+def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
 
 DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
 INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
```
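A design note visible in this hunk: least_upper_dtype is wrapped in functools.lru_cache, while least_upper_float stays uncached, which matters now that it reads the mutable dtypes.default_float. A small sketch of the semantics this preserves (assuming this commit):

```python
from tinygrad import dtypes
from tinygrad.dtype import least_upper_float

dtypes.default_float = dtypes.float16
assert least_upper_float(dtypes.int32) == dtypes.float16
dtypes.default_float = dtypes.float64
# An lru_cached least_upper_float would still return float16 here; the
# uncached version tracks the new default. (least_upper_dtype's own cache
# stays safe because default_float is passed in as an argument.)
assert least_upper_float(dtypes.int32) == dtypes.float64
```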
