Skip to content

Commit 31de70c

Browse files
acolombbizfsc
andauthored
Modernize type annotations and fix some discrepancies (#451)
* Use postponed evaluation of type annotations (PEP-563 / PEP-649 style) instead of stringized type annotations. * Clarify that each SDO object has its corresponding OD entry linked in the .od attribute, not the whole ObjectDictionary. * __iter__() should return Iterator like in Mapping, not Iterable. * Annotate PdoVariable.pdo_parent type. * Conditionally import types from other sub-packages to avoid dependency cycles. Co-authored-by: Frieder Schüler <frieder.schueler@bizerba.com>
1 parent fc577d1 commit 31de70c

File tree

4 files changed

+64
-52
lines changed

4 files changed

+64
-52
lines changed

canopen/network.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
from collections.abc import MutableMapping
24
import logging
35
import threading
4-
from typing import Callable, Dict, Iterable, List, Optional, Union
6+
from typing import Callable, Dict, Iterator, List, Optional, Union
57

68
try:
79
import can
@@ -82,7 +84,7 @@ def unsubscribe(self, can_id, callback=None) -> None:
8284
else:
8385
self.subscribers[can_id].remove(callback)
8486

85-
def connect(self, *args, **kwargs) -> "Network":
87+
def connect(self, *args, **kwargs) -> Network:
8688
"""Connect to CAN bus using python-can.
8789
8890
Arguments are passed directly to :class:`can.BusABC`. Typically these
@@ -214,7 +216,7 @@ def send_message(self, can_id: int, data: bytes, remote: bool = False) -> None:
214216

215217
def send_periodic(
216218
self, can_id: int, data: bytes, period: float, remote: bool = False
217-
) -> "PeriodicMessageTask":
219+
) -> PeriodicMessageTask:
218220
"""Start sending a message periodically.
219221
220222
:param can_id:
@@ -277,7 +279,7 @@ def __delitem__(self, node_id: int):
277279
self.nodes[node_id].remove_network()
278280
del self.nodes[node_id]
279281

280-
def __iter__(self) -> Iterable[int]:
282+
def __iter__(self) -> Iterator[int]:
281283
return iter(self.nodes)
282284

283285
def __len__(self) -> int:

canopen/objectdictionary/__init__.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""
22
Object Dictionary module
33
"""
4+
from __future__ import annotations
5+
46
import struct
5-
from typing import Dict, Iterable, List, Optional, TextIO, Union
7+
from typing import Dict, Iterator, List, Optional, TextIO, Union
68
from collections.abc import MutableMapping, Mapping
79
import logging
810

@@ -13,7 +15,7 @@
1315
logger = logging.getLogger(__name__)
1416

1517

16-
def export_od(od, dest:Union[str,TextIO,None]=None, doc_type:Optional[str]=None):
18+
def export_od(od, dest: Union[str, TextIO, None] = None, doc_type: Optional[str] = None):
1719
""" Export :class: ObjectDictionary to a file.
1820
1921
:param od:
@@ -55,7 +57,7 @@ def export_od(od, dest:Union[str,TextIO,None]=None, doc_type:Optional[str]=None)
5557
def import_od(
5658
source: Union[str, TextIO, None],
5759
node_id: Optional[int] = None,
58-
) -> "ObjectDictionary":
60+
) -> ObjectDictionary:
5961
"""Parse an EDS, DCF, or EPF file.
6062
6163
:param source:
@@ -102,7 +104,7 @@ def __init__(self):
102104

103105
def __getitem__(
104106
self, index: Union[int, str]
105-
) -> Union["ODArray", "ODRecord", "ODVariable"]:
107+
) -> Union[ODArray, ODRecord, ODVariable]:
106108
"""Get object from object dictionary by name or index."""
107109
item = self.names.get(index) or self.indices.get(index)
108110
if item is None:
@@ -113,7 +115,7 @@ def __getitem__(
113115
return item
114116

115117
def __setitem__(
116-
self, index: Union[int, str], obj: Union["ODArray", "ODRecord", "ODVariable"]
118+
self, index: Union[int, str], obj: Union[ODArray, ODRecord, ODVariable]
117119
):
118120
assert index == obj.index or index == obj.name
119121
self.add_object(obj)
@@ -123,7 +125,7 @@ def __delitem__(self, index: Union[int, str]):
123125
del self.indices[obj.index]
124126
del self.names[obj.name]
125127

126-
def __iter__(self) -> Iterable[int]:
128+
def __iter__(self) -> Iterator[int]:
127129
return iter(sorted(self.indices))
128130

129131
def __len__(self) -> int:
@@ -132,7 +134,7 @@ def __len__(self) -> int:
132134
def __contains__(self, index: Union[int, str]):
133135
return index in self.names or index in self.indices
134136

135-
def add_object(self, obj: Union["ODArray", "ODRecord", "ODVariable"]) -> None:
137+
def add_object(self, obj: Union[ODArray, ODRecord, ODVariable]) -> None:
136138
"""Add object to the object dictionary.
137139
138140
:param obj:
@@ -147,7 +149,7 @@ def add_object(self, obj: Union["ODArray", "ODRecord", "ODVariable"]) -> None:
147149

148150
def get_variable(
149151
self, index: Union[int, str], subindex: int = 0
150-
) -> Optional["ODVariable"]:
152+
) -> Optional[ODVariable]:
151153
"""Get the variable object at specified index (and subindex if applicable).
152154
153155
:return: ODVariable if found, else `None`
@@ -182,13 +184,13 @@ def __init__(self, name: str, index: int):
182184
def __repr__(self) -> str:
183185
return f"<{type(self).__qualname__} {self.name!r} at {pretty_index(self.index)}>"
184186

185-
def __getitem__(self, subindex: Union[int, str]) -> "ODVariable":
187+
def __getitem__(self, subindex: Union[int, str]) -> ODVariable:
186188
item = self.names.get(subindex) or self.subindices.get(subindex)
187189
if item is None:
188190
raise KeyError(f"Subindex {pretty_index(None, subindex)} was not found")
189191
return item
190192

191-
def __setitem__(self, subindex: Union[int, str], var: "ODVariable"):
193+
def __setitem__(self, subindex: Union[int, str], var: ODVariable):
192194
assert subindex == var.subindex
193195
self.add_member(var)
194196

@@ -200,16 +202,16 @@ def __delitem__(self, subindex: Union[int, str]):
200202
def __len__(self) -> int:
201203
return len(self.subindices)
202204

203-
def __iter__(self) -> Iterable[int]:
205+
def __iter__(self) -> Iterator[int]:
204206
return iter(sorted(self.subindices))
205207

206208
def __contains__(self, subindex: Union[int, str]) -> bool:
207209
return subindex in self.names or subindex in self.subindices
208210

209-
def __eq__(self, other: "ODRecord") -> bool:
211+
def __eq__(self, other: ODRecord) -> bool:
210212
return self.index == other.index
211213

212-
def add_member(self, variable: "ODVariable") -> None:
214+
def add_member(self, variable: ODVariable) -> None:
213215
"""Adds a :class:`~canopen.objectdictionary.ODVariable` to the record."""
214216
variable.parent = self
215217
self.subindices[variable.subindex] = variable
@@ -241,7 +243,7 @@ def __init__(self, name: str, index: int):
241243
def __repr__(self) -> str:
242244
return f"<{type(self).__qualname__} {self.name!r} at {pretty_index(self.index)}>"
243245

244-
def __getitem__(self, subindex: Union[int, str]) -> "ODVariable":
246+
def __getitem__(self, subindex: Union[int, str]) -> ODVariable:
245247
var = self.names.get(subindex) or self.subindices.get(subindex)
246248
if var is not None:
247249
# This subindex is defined
@@ -264,13 +266,13 @@ def __getitem__(self, subindex: Union[int, str]) -> "ODVariable":
264266
def __len__(self) -> int:
265267
return len(self.subindices)
266268

267-
def __iter__(self) -> Iterable[int]:
269+
def __iter__(self) -> Iterator[int]:
268270
return iter(sorted(self.subindices))
269271

270-
def __eq__(self, other: "ODArray") -> bool:
272+
def __eq__(self, other: ODArray) -> bool:
271273
return self.index == other.index
272274

273-
def add_member(self, variable: "ODVariable") -> None:
275+
def add_member(self, variable: ODVariable) -> None:
274276
"""Adds a :class:`~canopen.objectdictionary.ODVariable` to the record."""
275277
variable.parent = self
276278
self.subindices[variable.subindex] = variable
@@ -348,7 +350,7 @@ def qualname(self) -> str:
348350
return f"{self.parent.name}.{self.name}"
349351
return self.name
350352

351-
def __eq__(self, other: "ODVariable") -> bool:
353+
def __eq__(self, other: ODVariable) -> bool:
352354
return (self.index == other.index and
353355
self.subindex == other.subindex)
354356

canopen/pdo/base.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from __future__ import annotations
12
import threading
23
import math
3-
from typing import Callable, Dict, Iterable, List, Optional, Union
4+
from typing import Callable, Dict, Iterator, List, Optional, Union, TYPE_CHECKING
45
from collections.abc import Mapping
56
import logging
67
import binascii
@@ -9,6 +10,12 @@
910
from canopen import objectdictionary
1011
from canopen import variable
1112

13+
if TYPE_CHECKING:
14+
from canopen.network import Network
15+
from canopen import LocalNode, RemoteNode
16+
from canopen.pdo import RPDO, TPDO
17+
from canopen.sdo import SdoRecord
18+
1219
PDO_NOT_VALID = 1 << 31
1320
RTR_NOT_ALLOWED = 1 << 30
1421

@@ -22,10 +29,10 @@ class PdoBase(Mapping):
2229
Parent object associated with this PDO instance
2330
"""
2431

25-
def __init__(self, node):
26-
self.network = None
27-
self.map = None # instance of PdoMaps
28-
self.node = node
32+
def __init__(self, node: Union[LocalNode, RemoteNode]):
33+
self.network: Optional[Network] = None
34+
self.map: Optional[PdoMaps] = None
35+
self.node: Union[LocalNode, RemoteNode] = node
2936

3037
def __iter__(self):
3138
return iter(self.map)
@@ -131,7 +138,7 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None):
131138
:param pdo_node:
132139
:param cob_base:
133140
"""
134-
self.maps: Dict[int, "PdoMap"] = {}
141+
self.maps: Dict[int, PdoMap] = {}
135142
for map_no in range(512):
136143
if com_offset + map_no in pdo_node.node.object_dictionary:
137144
new_map = PdoMap(
@@ -143,10 +150,10 @@ def __init__(self, com_offset, map_offset, pdo_node: PdoBase, cob_base=None):
143150
new_map.predefined_cob_id = cob_base + map_no * 0x100 + pdo_node.node.id
144151
self.maps[map_no + 1] = new_map
145152

146-
def __getitem__(self, key: int) -> "PdoMap":
153+
def __getitem__(self, key: int) -> PdoMap:
147154
return self.maps[key]
148155

149-
def __iter__(self) -> Iterable[int]:
156+
def __iter__(self) -> Iterator[int]:
150157
return iter(self.maps)
151158

152159
def __len__(self) -> int:
@@ -157,9 +164,9 @@ class PdoMap:
157164
"""One message which can have up to 8 bytes of variables mapped."""
158165

159166
def __init__(self, pdo_node, com_record, map_array):
160-
self.pdo_node = pdo_node
161-
self.com_record = com_record
162-
self.map_array = map_array
167+
self.pdo_node: Union[TPDO, RPDO] = pdo_node
168+
self.com_record: SdoRecord = com_record
169+
self.map_array: SdoRecord = map_array
163170
#: If this map is valid
164171
self.enabled: bool = False
165172
#: COB-ID for this PDO
@@ -177,7 +184,7 @@ def __init__(self, pdo_node, com_record, map_array):
177184
#: Ignores SYNC objects up to this SYNC counter value (optional)
178185
self.sync_start_value: Optional[int] = None
179186
#: List of variables mapped to this PDO
180-
self.map: List["PdoVariable"] = []
187+
self.map: List[PdoVariable] = []
181188
self.length: int = 0
182189
#: Current message data
183190
self.data = bytearray()
@@ -214,7 +221,7 @@ def __getitem_by_name(self, value):
214221
raise KeyError(f"{value} not found in map. Valid entries are "
215222
f"{', '.join(valid_values)}")
216223

217-
def __getitem__(self, key: Union[int, str]) -> "PdoVariable":
224+
def __getitem__(self, key: Union[int, str]) -> PdoVariable:
218225
if isinstance(key, int):
219226
# there is a maximum available of 8 slots per PDO map
220227
if key in range(0, 8):
@@ -228,7 +235,7 @@ def __getitem__(self, key: Union[int, str]) -> "PdoVariable":
228235
var = self.__getitem_by_name(key)
229236
return var
230237

231-
def __iter__(self) -> Iterable["PdoVariable"]:
238+
def __iter__(self) -> Iterator[PdoVariable]:
232239
return iter(self.map)
233240

234241
def __len__(self) -> int:
@@ -303,7 +310,7 @@ def on_message(self, can_id, data, timestamp):
303310
for callback in self.callbacks:
304311
callback(self)
305312

306-
def add_callback(self, callback: Callable[["PdoMap"], None]) -> None:
313+
def add_callback(self, callback: Callable[[PdoMap], None]) -> None:
307314
"""Add a callback which will be called on receive.
308315
309316
:param callback:
@@ -451,7 +458,7 @@ def add_variable(
451458
index: Union[str, int],
452459
subindex: Union[str, int] = 0,
453460
length: Optional[int] = None,
454-
) -> "PdoVariable":
461+
) -> PdoVariable:
455462
"""Add a variable from object dictionary as the next entry.
456463
457464
:param index: Index of variable as name or number
@@ -544,7 +551,7 @@ class PdoVariable(variable.Variable):
544551

545552
def __init__(self, od: objectdictionary.ODVariable):
546553
#: PDO object that is associated with this ODVariable Object
547-
self.pdo_parent = None
554+
self.pdo_parent: Optional[PdoMap] = None
548555
#: Location of variable in the message in bits
549556
self.offset = None
550557
self.length = len(od)

0 commit comments

Comments
 (0)