5
5
"""
6
6
7
7
import gc
8
+ import sys
8
9
import time
10
+ import traceback
9
11
import unittest
12
+ from threading import Thread
10
13
from time import sleep
11
14
from typing import List
12
15
from unittest .mock import MagicMock
@@ -87,6 +90,8 @@ def test_removing_bus_tasks(self):
87
90
# Note calling task.stop will remove the task from the Bus's internal task management list
88
91
task .stop ()
89
92
93
+ self .join_threads ([task .thread for task in tasks ], 5.0 )
94
+
90
95
assert len (bus ._periodic_tasks ) == 0
91
96
bus .shutdown ()
92
97
@@ -115,8 +120,7 @@ def test_managed_tasks(self):
115
120
for task in tasks :
116
121
task .stop ()
117
122
118
- for task in tasks :
119
- assert task .thread .join (5.0 ) is None , "Task didn't stop before timeout"
123
+ self .join_threads ([task .thread for task in tasks ], 5.0 )
120
124
121
125
bus .shutdown ()
122
126
@@ -142,9 +146,7 @@ def test_stopping_perodic_tasks(self):
142
146
143
147
# stop the other half using the bus api
144
148
bus .stop_all_periodic_tasks (remove_tasks = False )
145
-
146
- for task in tasks :
147
- assert task .thread .join (5.0 ) is None , "Task didn't stop before timeout"
149
+ self .join_threads ([task .thread for task in tasks ], 5.0 )
148
150
149
151
# Tasks stopped via `stop_all_periodic_tasks` with remove_tasks=False should
150
152
# still be associated with the bus (e.g. for restarting)
@@ -161,7 +163,7 @@ def test_restart_perodic_tasks(self):
161
163
is_extended_id = False , arbitration_id = 0x123 , data = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ]
162
164
)
163
165
164
- def _read_all_messages (_bus : can .interfaces .virtual .VirtualBus ) -> None :
166
+ def _read_all_messages (_bus : " can.interfaces.virtual.VirtualBus" ) -> None :
165
167
sleep (safe_timeout )
166
168
while not _bus .queue .empty ():
167
169
_bus .recv (timeout = period )
@@ -207,9 +209,8 @@ def _read_all_messages(_bus: can.interfaces.virtual.VirtualBus) -> None:
207
209
208
210
# Stop all tasks and wait for the thread to exit
209
211
bus .stop_all_periodic_tasks ()
210
- if isinstance (task , can .broadcastmanager .ThreadBasedCyclicSendTask ):
211
- # Avoids issues where the thread is still running when the bus is shutdown
212
- task .thread .join (safe_timeout )
212
+ # Avoids issues where the thread is still running when the bus is shutdown
213
+ self .join_threads ([task .thread ], 5.0 )
213
214
214
215
@unittest .skipIf (IS_CI , "fails randomly when run on CI server" )
215
216
def test_thread_based_cyclic_send_task (self ):
@@ -288,6 +289,27 @@ def increment_first_byte(msg: can.Message) -> None:
288
289
self .assertEqual (b"\x06 \x00 \x00 \x00 \x00 \x00 \x00 \x00 " , bytes (msg_list [5 ].data ))
289
290
self .assertEqual (b"\x07 \x00 \x00 \x00 \x00 \x00 \x00 \x00 " , bytes (msg_list [6 ].data ))
290
291
292
+ @staticmethod
293
+ def join_threads (threads : List [Thread ], timeout : float ) -> None :
294
+ stuck_threads : List [Thread ] = []
295
+ t0 = time .perf_counter ()
296
+ for thread in threads :
297
+ time_left = timeout - (time .perf_counter () - t0 )
298
+ if time_left > 0.0 :
299
+ thread .join (time_left )
300
+ if thread .is_alive ():
301
+ if platform .python_implementation () == "CPython" :
302
+ # print thread frame to help with debugging
303
+ frame = sys ._current_frames ()[thread .ident ]
304
+ traceback .print_stack (frame , file = sys .stderr )
305
+ stuck_threads .append (thread )
306
+ if stuck_threads :
307
+ err_message = (
308
+ f"Threads did not stop within { timeout :.1f} seconds: "
309
+ f"[{ ', ' .join ([str (t ) for t in stuck_threads ])} ]"
310
+ )
311
+ raise RuntimeError (err_message )
312
+
291
313
292
314
if __name__ == "__main__" :
293
315
unittest .main ()
0 commit comments