@@ -633,16 +633,22 @@ def test_dtype_promo(self):
633
633
assert least_upper_dtype (dtypes .float16 , dtypes .int64 ) == dtypes .float16
634
634
assert least_upper_dtype (dtypes .float16 , dtypes .uint64 ) == dtypes .float16
635
635
636
- @given (strat .sampled_from (dtype_floats ))
637
- def test_float_to_float (self , dt ):
638
- assert least_upper_float (dt ) == dt
639
-
640
636
class TestAutoCastType (unittest .TestCase ):
641
637
def setUp (self ):
642
638
self .old_default_int , self .old_default_float = dtypes .default_int , dtypes .default_float
643
639
def tearDown (self ):
644
640
dtypes .default_int , dtypes .default_float = self .old_default_int , self .old_default_float
645
641
642
+ @given (strat .sampled_from (dtype_floats ), strat .sampled_from (dtype_floats ))
643
+ def test_least_upper_float_input_is_float (self , input_dtype , default_float ):
644
+ dtypes .default_float = default_float
645
+ self .assertEqual (least_upper_float (input_dtype ), input_dtype )
646
+
647
+ @given (strat .sampled_from (dtype_ints ), strat .sampled_from (dtype_floats ))
648
+ def test_least_upper_float_input_is_int (self , input_dtype , default_float ):
649
+ dtypes .default_float = default_float
650
+ self .assertEqual (least_upper_float (input_dtype ), default_float )
651
+
646
652
@given (strat .sampled_from ([d for d in core_dtypes if dtypes .is_int (d ) and is_dtype_supported (d )]))
647
653
def test_int_to_float_unary_func (self , dtype ):
648
654
for func in [
@@ -667,6 +673,11 @@ def test_broadcast_scalar(self, dt):
667
673
assert (Tensor .ones (4 , 4 , dtype = dt ) + 2 ).dtype == (dt if dtypes .is_float (dt ) or dtypes .is_int (dt ) else dtypes .default_int )
668
674
assert (Tensor .ones (4 , 4 , dtype = dt ) + True ).dtype == dt
669
675
676
+ @given (strat .sampled_from (dtype_floats ))
677
+ def test_int_div_int (self , default_float ):
678
+ dtypes .default_float = default_float
679
+ self .assertEqual (Tensor ([1 ]).div (Tensor ([2 ])).dtype , default_float )
680
+
670
681
def test_sum (self ):
671
682
assert (Tensor ([0 , 1 ], dtype = dtypes .bool )).sum ().dtype == dtypes .int32
672
683
assert (Tensor ([0 , 1 ], dtype = dtypes .int8 )).sum ().dtype == dtypes .int32
0 commit comments