|
1 |
| -from typing import Iterable, Union |
| 1 | +from typing import Iterable, Optional, Union |
2 | 2 |
|
3 | 3 | import requests
|
4 | 4 |
|
| 5 | +from .retry import RegularPeriodRetry, RetryMechanism |
| 6 | + |
5 | 7 | from .submission import Submission
|
6 | 8 |
|
7 | 9 |
|
@@ -160,6 +162,36 @@ def get_submissions(
|
160 | 162 | for submission, attrs in zip(submissions, resp.json()["submissions"]):
|
161 | 163 | submission.set_attributes(attrs)
|
162 | 164 |
|
| 165 | + def wait( |
| 166 | + self, |
| 167 | + submissions: Union[Submission, list[Submission]], |
| 168 | + *, |
| 169 | + retry_mechanism: Optional[RetryMechanism] = None, |
| 170 | + ): |
| 171 | + if retry_mechanism is None: |
| 172 | + retry_mechanism = RegularPeriodRetry() |
| 173 | + |
| 174 | + if not isinstance(submissions, (list, tuple)): |
| 175 | + submissions = [submissions] |
| 176 | + |
| 177 | + submissions_to_check = { |
| 178 | + submission.token: submission for submission in submissions |
| 179 | + } |
| 180 | + |
| 181 | + while len(submissions_to_check) > 0 and not retry_mechanism.is_done(): |
| 182 | + self.get_submissions(submissions_to_check.values()) |
| 183 | + for token in list(submissions_to_check): |
| 184 | + submission = submissions_to_check[token] |
| 185 | + if submission.is_done(): |
| 186 | + submissions_to_check.pop(token) |
| 187 | + |
| 188 | + # Don't wait if there is no submissions to check for anymore. |
| 189 | + if len(submissions_to_check) == 0: |
| 190 | + break |
| 191 | + |
| 192 | + retry_mechanism.wait() |
| 193 | + retry_mechanism.step() |
| 194 | + |
163 | 195 |
|
164 | 196 | class ATD(Client):
|
165 | 197 | def __init__(self, endpoint, host_header_value, api_key):
|
|
0 commit comments