Skip to content

Commit 220537d

Browse files
Improve Balance class
1 parent 38f2156 commit 220537d

File tree

1 file changed

+67
-32
lines changed

1 file changed

+67
-32
lines changed

bittensor/utils/balance.py

+67-32
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,18 @@
66
from bittensor.core import settings
77

88

9+
def _check_currencies(self, other):
10+
"""Checks that Balance objects have the same netuids to perform arithmetic operations."""
11+
if self.netuid != other.netuid:
12+
warnings.simplefilter("default", DeprecationWarning)
13+
warnings.warn(
14+
"Balance objects must have the same netuid (Alpha currency) to perform arithmetic operations. "
15+
f"First balance is `{self}`. Second balance is `{other}`. ",
16+
category=DeprecationWarning,
17+
stacklevel=2,
18+
)
19+
20+
921
class Balance:
1022
"""
1123
Represents the bittensor balance of the wallet, stored as rao (int).
@@ -23,6 +35,7 @@ class Balance:
2335
rao_unit: str = settings.RAO_SYMBOL # This is the rao unit
2436
rao: int
2537
tao: float
38+
netuid: int = 0
2639

2740
def __init__(self, balance: Union[int, float]):
2841
"""
@@ -78,7 +91,8 @@ def __eq__(self, other: Union[int, float, "Balance"]):
7891
if other is None:
7992
return False
8093

81-
if hasattr(other, "rao"):
94+
if isinstance(other, Balance):
95+
_check_currencies(self, other)
8296
return self.rao == other.rao
8397
else:
8498
try:
@@ -92,7 +106,8 @@ def __ne__(self, other: Union[int, float, "Balance"]):
92106
return not self == other
93107

94108
def __gt__(self, other: Union[int, float, "Balance"]):
95-
if hasattr(other, "rao"):
109+
if isinstance(other, Balance):
110+
_check_currencies(self, other)
96111
return self.rao > other.rao
97112
else:
98113
try:
@@ -103,7 +118,8 @@ def __gt__(self, other: Union[int, float, "Balance"]):
103118
raise NotImplementedError("Unsupported type")
104119

105120
def __lt__(self, other: Union[int, float, "Balance"]):
106-
if hasattr(other, "rao"):
121+
if isinstance(other, Balance):
122+
_check_currencies(self, other)
107123
return self.rao < other.rao
108124
else:
109125
try:
@@ -115,111 +131,129 @@ def __lt__(self, other: Union[int, float, "Balance"]):
115131

116132
def __le__(self, other: Union[int, float, "Balance"]):
117133
try:
134+
if isinstance(other, Balance):
135+
_check_currencies(self, other)
118136
return self < other or self == other
119137
except TypeError:
120138
raise NotImplementedError("Unsupported type")
121139

122140
def __ge__(self, other: Union[int, float, "Balance"]):
123141
try:
142+
if isinstance(other, Balance):
143+
_check_currencies(self, other)
124144
return self > other or self == other
125145
except TypeError:
126146
raise NotImplementedError("Unsupported type")
127147

128148
def __add__(self, other: Union[int, float, "Balance"]):
129-
if hasattr(other, "rao"):
130-
return Balance.from_rao(int(self.rao + other.rao))
149+
if isinstance(other, Balance):
150+
_check_currencies(self, other)
151+
return Balance.from_rao(int(self.rao + other.rao)).set_unit(self.netuid)
131152
else:
132153
try:
133154
# Attempt to cast to int from rao
134-
return Balance.from_rao(int(self.rao + other))
155+
return Balance.from_rao(int(self.rao + other)).set_unit(self.netuid)
135156
except (ValueError, TypeError):
136157
raise NotImplementedError("Unsupported type")
137158

138159
def __radd__(self, other: Union[int, float, "Balance"]):
139160
try:
161+
if isinstance(other, Balance):
162+
_check_currencies(self, other)
140163
return self + other
141164
except TypeError:
142165
raise NotImplementedError("Unsupported type")
143166

144167
def __sub__(self, other: Union[int, float, "Balance"]):
145168
try:
169+
if isinstance(other, Balance):
170+
_check_currencies(self, other)
146171
return self + -other
147172
except TypeError:
148173
raise NotImplementedError("Unsupported type")
149174

150175
def __rsub__(self, other: Union[int, float, "Balance"]):
151176
try:
177+
if isinstance(other, Balance):
178+
_check_currencies(self, other)
152179
return -self + other
153180
except TypeError:
154181
raise NotImplementedError("Unsupported type")
155182

156183
def __mul__(self, other: Union[int, float, "Balance"]):
157-
if hasattr(other, "rao"):
158-
return Balance.from_rao(int(self.rao * other.rao))
184+
if isinstance(other, Balance):
185+
_check_currencies(self, other)
186+
return Balance.from_rao(int(self.rao * other.rao)).set_unit(self.netuid)
159187
else:
160188
try:
161189
# Attempt to cast to int from rao
162-
return Balance.from_rao(int(self.rao * other))
190+
return Balance.from_rao(int(self.rao * other)).set_unit(self.netuid)
163191
except (ValueError, TypeError):
164192
raise NotImplementedError("Unsupported type")
165193

166194
def __rmul__(self, other: Union[int, float, "Balance"]):
195+
if isinstance(other, Balance):
196+
_check_currencies(self, other)
167197
return self * other
168198

169199
def __truediv__(self, other: Union[int, float, "Balance"]):
170-
if hasattr(other, "rao"):
171-
return Balance.from_rao(int(self.rao / other.rao))
200+
if isinstance(other, Balance):
201+
_check_currencies(self, other)
202+
return Balance.from_rao(int(self.rao / other.rao)).set_unit(self.netuid)
172203
else:
173204
try:
174205
# Attempt to cast to int from rao
175-
return Balance.from_rao(int(self.rao / other))
206+
return Balance.from_rao(int(self.rao / other)).set_unit(self.netuid)
176207
except (ValueError, TypeError):
177208
raise NotImplementedError("Unsupported type")
178209

179210
def __rtruediv__(self, other: Union[int, float, "Balance"]):
180-
if hasattr(other, "rao"):
181-
return Balance.from_rao(int(other.rao / self.rao))
211+
if isinstance(other, Balance):
212+
_check_currencies(self, other)
213+
return Balance.from_rao(int(other.rao / self.rao)).set_unit(self.netuid)
182214
else:
183215
try:
184216
# Attempt to cast to int from rao
185-
return Balance.from_rao(int(other / self.rao))
217+
return Balance.from_rao(int(other / self.rao)).set_unit(self.netuid)
186218
except (ValueError, TypeError):
187219
raise NotImplementedError("Unsupported type")
188220

189221
def __floordiv__(self, other: Union[int, float, "Balance"]):
190-
if hasattr(other, "rao"):
191-
return Balance.from_rao(int(self.tao // other.tao))
222+
if isinstance(other, Balance):
223+
_check_currencies(self, other)
224+
return Balance.from_rao(int(self.tao // other.tao)).set_unit(self.netuid)
192225
else:
193226
try:
194227
# Attempt to cast to int from rao
195-
return Balance.from_rao(int(self.rao // other))
228+
return Balance.from_rao(int(self.rao // other)).set_unit(self.netuid)
196229
except (ValueError, TypeError):
197230
raise NotImplementedError("Unsupported type")
198231

199232
def __rfloordiv__(self, other: Union[int, float, "Balance"]):
200-
if hasattr(other, "rao"):
201-
return Balance.from_rao(int(other.rao // self.rao))
233+
if isinstance(other, Balance):
234+
_check_currencies(self, other)
235+
return Balance.from_rao(int(other.rao // self.rao)).set_unit(self.netuid)
202236
else:
203237
try:
204238
# Attempt to cast to int from rao
205-
return Balance.from_rao(int(other // self.rao))
239+
return Balance.from_rao(int(other // self.rao)).set_unit(self.netuid)
206240
except (ValueError, TypeError):
207241
raise NotImplementedError("Unsupported type")
208242

209243
def __nonzero__(self) -> bool:
210244
return bool(self.rao)
211245

212246
def __neg__(self):
213-
return Balance.from_rao(-self.rao)
247+
return Balance.from_rao(-self.rao).set_unit(self.netuid)
214248

215249
def __pos__(self):
216-
return Balance.from_rao(self.rao)
250+
return Balance.from_rao(self.rao).set_unit(self.netuid)
217251

218252
def __abs__(self):
219-
return Balance.from_rao(abs(self.rao))
253+
return Balance.from_rao(abs(self.rao)).set_unit(self.netuid)
220254

221255
@staticmethod
222-
def from_float(amount: float, netuid: int = 0):
256+
def from_float(amount: float, netuid: int = 0) -> "Balance":
223257
"""
224258
Given tao, return :func:`Balance` object with rao(``int``) and tao(``float``), where rao = int(tao*pow(10,9))
225259
Args:
@@ -233,7 +267,7 @@ def from_float(amount: float, netuid: int = 0):
233267
return Balance(rao_).set_unit(netuid)
234268

235269
@staticmethod
236-
def from_tao(amount: float, netuid: int = 0):
270+
def from_tao(amount: float, netuid: int = 0) -> "Balance":
237271
"""
238272
Given tao, return Balance object with rao(``int``) and tao(``float``), where rao = int(tao*pow(10,9))
239273
@@ -248,7 +282,7 @@ def from_tao(amount: float, netuid: int = 0):
248282
return Balance(rao_).set_unit(netuid)
249283

250284
@staticmethod
251-
def from_rao(amount: int, netuid: int = 0):
285+
def from_rao(amount: int, netuid: int = 0) -> "Balance":
252286
"""
253287
Given rao, return Balance object with rao(``int``) and tao(``float``), where rao = int(tao*pow(10,9))
254288
@@ -262,7 +296,7 @@ def from_rao(amount: int, netuid: int = 0):
262296
return Balance(amount).set_unit(netuid)
263297

264298
@staticmethod
265-
def get_unit(netuid: int):
299+
def get_unit(netuid: int) -> str:
266300
base = len(units)
267301
if netuid < base:
268302
return units[netuid]
@@ -274,6 +308,7 @@ def get_unit(netuid: int):
274308
return result
275309

276310
def set_unit(self, netuid: int):
311+
self.netuid = netuid
277312
self.unit = Balance.get_unit(netuid)
278313
self.rao_unit = Balance.get_unit(netuid)
279314
return self
@@ -777,18 +812,18 @@ def fixed_to_float(
777812
]
778813

779814

780-
def tao(amount: float) -> Balance:
815+
def tao(amount: float, netuid: int = 0) -> Balance:
781816
"""
782817
Helper function to create a Balance object from a float (Tao)
783818
"""
784-
return Balance.from_tao(amount)
819+
return Balance.from_tao(amount).set_unit(netuid)
785820

786821

787-
def rao(amount: int) -> Balance:
822+
def rao(amount: int, netuid: int = 0) -> Balance:
788823
"""
789824
Helper function to create a Balance object from an int (Rao)
790825
"""
791-
return Balance.from_rao(amount)
826+
return Balance.from_rao(amount).set_unit(netuid)
792827

793828

794829
def check_and_convert_to_balance(

0 commit comments

Comments
 (0)