Skip to content

Commit 8dd8b4f

Browse files
committed
Cleanup for mypy
1 parent 8582f09 commit 8dd8b4f

File tree

1 file changed

+56
-50
lines changed

1 file changed

+56
-50
lines changed

can/io/mf4.py

Lines changed: 56 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
the ASAM MDF standard (see https://www.asam.net/standards/detail/mdf/)
66
"""
77

8+
import abc
89
import logging
910
from datetime import datetime
1011
from hashlib import md5
1112
from io import BufferedIOBase, BytesIO
1213
from pathlib import Path
13-
from typing import Any, BinaryIO, Generator, Iterable, Optional, Union, cast
14+
from typing import Any, BinaryIO, Dict, Generator, Iterator, List, Optional, Union, cast
1415

1516
from ..message import Message
1617
from ..typechecking import StringPathLike
@@ -70,6 +71,8 @@
7071
)
7172
except ImportError:
7273
asammdf = None
74+
MDF4 = None
75+
Signal = None
7376

7477

7578
CAN_MSG_EXT = 0x80000000
@@ -266,60 +269,63 @@ def on_message_received(self, msg: Message) -> None:
266269
self._rtr_buffer = np.zeros(1, dtype=RTR_DTYPE)
267270

268271

269-
class MF4Reader(BinaryIOMessageReader):
272+
class FrameIterator(object, metaclass=abc.ABCMeta):
270273
"""
271-
Iterator of CAN messages from a MF4 logging file.
272-
273-
The MF4Reader only supports MF4 files with CAN bus logging.
274+
Iterator helper class for common handling among CAN DataFrames, ErrorFrames and RemoteFrames.
274275
"""
275276

276-
# NOTE: Readout based on the bus logging code from asammdf GUI
277+
# Number of records to request for each asammdf call
278+
_chunk_size = 1000
277279

278-
class FrameIterator(object):
279-
"""
280-
Iterator helper class for common handling among CAN DataFrames, ErrorFrames and RemoteFrames.
281-
"""
280+
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float, name: str):
281+
self._mdf = mdf
282+
self._group_index = group_index
283+
self._start_timestamp = start_timestamp
284+
self._name = name
282285

283-
# Number of records to request for each asammdf call
284-
_chunk_size = 1000
286+
# Extract names
287+
channel_group: ChannelGroup = self._mdf.groups[self._group_index]
285288

286-
def __init__(
287-
self, mdf: MDF, group_index: int, start_timestamp: float, name: str
288-
):
289-
self._mdf = mdf
290-
self._group_index = group_index
291-
self._start_timestamp = start_timestamp
292-
self._name = name
289+
self._channel_names = []
293290

294-
# Extract names
295-
channel_group: ChannelGroup = self._mdf.groups[self._group_index]
291+
for channel in channel_group.channels:
292+
if str(channel.name).startswith(f"{self._name}."):
293+
self._channel_names.append(channel.name)
296294

297-
self._channel_names = []
295+
return
298296

299-
for channel in channel_group.channels:
300-
if str(channel.name).startswith(f"{self._name}."):
301-
self._channel_names.append(channel.name)
297+
def _get_data(self, current_offset: int) -> Signal:
298+
# NOTE: asammdf suggests using select instead of get. Select seem to miss converting some channels which
299+
# get does convert as expected.
300+
data_raw = self._mdf.get(
301+
self._name,
302+
self._group_index,
303+
record_offset=current_offset,
304+
record_count=self._chunk_size,
305+
raw=False,
306+
)
302307

303-
return
308+
return data_raw
304309

305-
def _get_data(self, current_offset: int) -> asammdf.Signal:
306-
# NOTE: asammdf suggests using select instead of get. Select seem to miss converting some channels which
307-
# get does convert as expected.
308-
data_raw = self._mdf.get(
309-
self._name,
310-
self._group_index,
311-
record_offset=current_offset,
312-
record_count=self._chunk_size,
313-
raw=False,
314-
)
310+
@abc.abstractmethod
311+
def __iter__(self) -> Generator[Message, None, None]:
312+
pass
315313

316-
return data_raw
314+
pass
317315

318-
pass
316+
317+
class MF4Reader(BinaryIOMessageReader):
318+
"""
319+
Iterator of CAN messages from a MF4 logging file.
320+
321+
The MF4Reader only supports MF4 files with CAN bus logging.
322+
"""
323+
324+
# NOTE: Readout based on the bus logging code from asammdf GUI
319325

320326
class CANDataFrameIterator(FrameIterator):
321327

322-
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
328+
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float):
323329
super().__init__(mdf, group_index, start_timestamp, "CAN_DataFrame")
324330

325331
return
@@ -336,7 +342,7 @@ def __iter__(self) -> Generator[Message, None, None]:
336342
for i in range(len(data)):
337343
data_length = int(data["CAN_DataFrame.DataLength"][i])
338344

339-
kv = {
345+
kv: Dict[str, Any] = {
340346
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
341347
"arbitration_id": int(data["CAN_DataFrame.ID"][i]) & 0x1FFFFFFF,
342348
"data": data["CAN_DataFrame.DataBytes"][i][
@@ -365,7 +371,7 @@ def __iter__(self) -> Generator[Message, None, None]:
365371

366372
class CANErrorFrameIterator(FrameIterator):
367373

368-
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
374+
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float):
369375
super().__init__(mdf, group_index, start_timestamp, "CAN_ErrorFrame")
370376

371377
return
@@ -380,7 +386,7 @@ def __iter__(self) -> Generator[Message, None, None]:
380386
names = data.samples[0].dtype.names
381387

382388
for i in range(len(data)):
383-
kv = {
389+
kv: Dict[str, Any] = {
384390
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
385391
"is_error_frame": True,
386392
}
@@ -422,7 +428,7 @@ def __iter__(self) -> Generator[Message, None, None]:
422428

423429
class CANRemoteFrameIterator(FrameIterator):
424430

425-
def __init__(self, mdf: MDF, group_index: int, start_timestamp: float):
431+
def __init__(self, mdf: MDF4, group_index: int, start_timestamp: float):
426432
super().__init__(mdf, group_index, start_timestamp, "CAN_RemoteFrame")
427433

428434
return
@@ -437,7 +443,7 @@ def __iter__(self) -> Generator[Message, None, None]:
437443
names = data.samples[0].dtype.names
438444

439445
for i in range(len(data)):
440-
kv = {
446+
kv: Dict[str, Any] = {
441447
"timestamp": float(data.timestamps[i]) + self._start_timestamp,
442448
"arbitration_id": int(data["CAN_RemoteFrame.ID"][i])
443449
& 0x1FFFFFFF,
@@ -476,20 +482,20 @@ def __init__(
476482

477483
super().__init__(file, mode="rb")
478484

479-
self._mdf: MDF
485+
self._mdf: MDF4
480486
if isinstance(file, BufferedIOBase):
481-
self._mdf = MDF(BytesIO(file.read()))
487+
self._mdf = cast(MDF4, MDF(BytesIO(file.read())))
482488
else:
483-
self._mdf = MDF(file)
489+
self._mdf = cast(MDF4, MDF(file))
484490

485491
self._start_timestamp = self._mdf.header.start_time.timestamp()
486492

487-
def __iter__(self) -> Iterable[Message]:
493+
def __iter__(self) -> Iterator[Message]:
488494
import heapq
489495

490496
# To handle messages split over multiple channel groups, create a single iterator per channel group and merge
491497
# these iterators into a single iterator using heapq.
492-
iterators = []
498+
iterators: List[FrameIterator] = []
493499
for group_index, group in enumerate(self._mdf.groups):
494500
channel_group: ChannelGroup = group.channel_group
495501

@@ -536,7 +542,7 @@ def __iter__(self) -> Iterable[Message]:
536542
continue
537543

538544
# Create merged iterator over all the groups, using the timestamps as comparison key
539-
return heapq.merge(*iterators, key=lambda x: x.timestamp)
545+
return iter(heapq.merge(*iterators, key=lambda x: x.timestamp))
540546

541547
def stop(self) -> None:
542548
self._mdf.close()

0 commit comments

Comments
 (0)