@@ -19,17 +19,22 @@ def __init__(self, *args, **kwargs):
19
19
self .iteration_task = asyncio .create_task (self ._run_iteration ())
20
20
21
21
async def _run_iteration(self):
    """Drive the generator and route every result to the job that owns it.

    Each pass waits on ``self.condition`` until ``self.jobs`` is non-empty,
    calls ``self.generator.iterate()``, and hands each result to
    ``self.jobs[result["job"]]`` through ``put_result``. A job is removed
    from ``self.jobs`` once one of its results has a truthy ``"eos"``.

    If any ``Exception`` escapes the loop, it is delivered to every job that
    is registered at that moment and the coroutine then returns; the loop is
    not restarted.
    """
    try:
        while True:
            async with self.condition:
                await self.condition.wait_for(lambda: len(self.jobs) > 0)
            # NOTE(review): nesting was reconstructed from a whitespace-stripped
            # diff; confirm that only wait_for() is meant to run under the lock.
            results = self.generator.iterate()
            for result in results:
                job = result["job"]
                async_job = self.jobs[job]
                await async_job.put_result(result)
                if result["eos"]:
                    del self.jobs[job]
            # Give other tasks a turn between iterations.
            await asyncio.sleep(0)
    except Exception as e:
        # If the generator throws an exception it won't pertain to any one
        # ongoing job, so push it to all of them. Iterate over a snapshot:
        # each put_result() is an await point, and a change to self.jobs while
        # the live view is being iterated would raise RuntimeError and leave
        # the remaining jobs without the error.
        for async_job in list(self.jobs.values()):
            await async_job.put_result(e)
33
38
34
39
def enqueue (self , job : ExLlamaV2DynamicJobAsync ):
35
40
assert job .job not in self .jobs
@@ -75,6 +80,8 @@ async def put_result(self, result):
75
80
async def __aiter__(self):
    """Asynchronously yield results taken from ``self.queue``.

    Items are awaited one at a time. An item that is an ``Exception``
    instance is raised instead of being yielded. Iteration stops after
    yielding an item whose ``"eos"`` entry is truthy.
    """
    finished = False
    while not finished:
        item = await self.queue.get()
        # Exceptions arrive through the same queue as ordinary results.
        if isinstance(item, Exception):
            raise item
        yield item
        finished = bool(item["eos"])
0 commit comments