Skip to content

Commit b848f5d

Browse files
yaugenst-flexmomchil-flex
authored andcommitted
enh[web]: make pay type selection case insensitive
1 parent 2916702 commit b848f5d

File tree

7 files changed

+92
-23
lines changed

7 files changed

+92
-23
lines changed

tests/test_web/test_tidy3d_task.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,67 @@ def test_submit(set_api_key):
256256
# monitor(TASK_ID, True)
257257

258258

259+
@responses.activate
260+
def test_pay_type_case_insensitivity(set_api_key):
261+
"""Test PayType enum's case-insensitive behavior with different string formats."""
262+
project_id = "1234"
263+
TASK_ID = "5678"
264+
task_name = "test pay type"
265+
266+
responses.add(
267+
responses.GET,
268+
f"{Env.current.web_api_endpoint}/tidy3d/project",
269+
match=[matchers.query_param_matcher({"projectName": "test pay type folder"})],
270+
json={"data": {"projectId": project_id, "projectName": "test pay type folder"}},
271+
status=200,
272+
)
273+
responses.add(
274+
responses.POST,
275+
f"{Env.current.web_api_endpoint}/tidy3d/projects/{project_id}/tasks",
276+
json={
277+
"data": {
278+
"taskId": TASK_ID,
279+
"taskName": task_name,
280+
"createdAt": "2022-01-01T00:00:00.000Z",
281+
}
282+
},
283+
status=200,
284+
)
285+
286+
responses.add(
287+
responses.POST,
288+
f"{Env.current.web_api_endpoint}/tidy3d/tasks/{TASK_ID}/submit",
289+
json={
290+
"data": {
291+
"taskId": TASK_ID,
292+
"taskName": task_name,
293+
"createdAt": "2022-01-01T00:00:00.000Z",
294+
"taskBlockInfo": {
295+
"chargeType": "free",
296+
"maxFreeCount": 20,
297+
"maxGridPoints": 1000,
298+
"maxTimeSteps": 1000,
299+
},
300+
}
301+
},
302+
status=200,
303+
)
304+
305+
task = SimulationTask.create(TaskType.FDTD, task_name, "test pay type folder")
306+
307+
valid_pay_types = [
308+
"auto",
309+
"AUTO",
310+
PayType.AUTO,
311+
"credits",
312+
"CREDITS",
313+
PayType.CREDITS,
314+
]
315+
316+
for pay_type in valid_pay_types:
317+
task.submit(pay_type=pay_type)
318+
319+
259320
@responses.activate
260321
def test_estimate_cost(set_api_key):
261322
TASK_ID = "3eb06d16-208b-487b-864b-e9b1d3e010a7"

tidy3d/web/api/asynchronous.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Interface to run several jobs in batch using simplified syntax."""
22

3-
from typing import Dict, List, Literal
3+
from typing import Dict, List, Literal, Union
44

55
from ...log import log
66
from ..core.types import PayType
@@ -18,7 +18,7 @@ def run_async(
1818
simulation_type: str = "tidy3d",
1919
parent_tasks: Dict[str, List[str]] = None,
2020
reduce_simulation: Literal["auto", True, False] = "auto",
21-
pay_type: PayType = PayType.AUTO,
21+
pay_type: Union[PayType, str] = PayType.AUTO,
2222
) -> BatchData:
2323
"""Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server,
2424
starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object.
@@ -42,7 +42,7 @@ def run_async(
4242
If ``True``, will print progressbars and status, otherwise, will run silently.
4343
reduce_simulation: Literal["auto", True, False] = "auto"
4444
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
45-
pay_type: PayType = PayType.AUTO
45+
pay_type: Union[PayType, str] = PayType.AUTO
4646
Specify the payment method.
4747
4848
Returns
@@ -60,7 +60,6 @@ def run_async(
6060
:class:`Batch`
6161
Interface for submitting several :class:`Simulation` objects to sever.
6262
"""
63-
6463
if simulation_type is None:
6564
simulation_type = "tidy3d"
6665

tidy3d/web/api/autograd/autograd.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def run(
107107
local_gradient: bool = LOCAL_GRADIENT,
108108
max_num_adjoint_per_fwd: int = MAX_NUM_ADJOINT_PER_FWD,
109109
reduce_simulation: Literal["auto", True, False] = "auto",
110-
pay_type: PayType = PayType.AUTO,
110+
pay_type: typing.Union[PayType, str] = PayType.AUTO,
111111
) -> SimulationDataType:
112112
"""
113113
Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads,
@@ -145,7 +145,7 @@ def run(
145145
Maximum number of adjoint simulations allowed to run automatically.
146146
reduce_simulation: Literal["auto", True, False] = "auto"
147147
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
148-
pay_type: PayType = AUTO
148+
pay_type: typing.Union[PayType, str] = PayType.AUTO
149149
Which method to pay for the simulation.
150150
Returns
151151
-------
@@ -191,7 +191,6 @@ def run(
191191
:meth:`tidy3d.web.api.container.Batch.monitor`
192192
Monitor progress of each of the running tasks.
193193
"""
194-
195194
if is_valid_for_autograd(simulation):
196195
return _run(
197196
simulation=simulation,
@@ -241,7 +240,7 @@ def run_async(
241240
local_gradient: bool = LOCAL_GRADIENT,
242241
max_num_adjoint_per_fwd: int = MAX_NUM_ADJOINT_PER_FWD,
243242
reduce_simulation: Literal["auto", True, False] = "auto",
244-
pay_type: PayType = PayType.AUTO,
243+
pay_type: typing.Union[PayType, str] = PayType.AUTO,
245244
) -> BatchData:
246245
"""Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server,
247246
starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object.
@@ -270,7 +269,7 @@ def run_async(
270269
Maximum number of adjoint simulations allowed to run automatically.
271270
reduce_simulation: Literal["auto", True, False] = "auto"
272271
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
273-
pay_type: PayType = PayType.AUTO
272+
pay_type: typing.Union[PayType, str] = PayType.AUTO
274273
Specify the payment method.
275274
276275
Returns
@@ -288,7 +287,6 @@ def run_async(
288287
:class:`Batch`
289288
Interface for submitting several :class:`Simulation` objects to sever.
290289
"""
291-
292290
if is_valid_for_autograd_async(simulations):
293291
return _run_async(
294292
simulations=simulations,

tidy3d/web/api/mode.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import tempfile
88
import time
99
from datetime import datetime
10-
from typing import Callable, List, Optional
10+
from typing import Callable, List, Optional, Union
1111

1212
import pydantic.v1 as pydantic
1313
from botocore.exceptions import ClientError
@@ -56,7 +56,7 @@ def run(
5656
progress_callback_upload: Callable[[float], None] = None,
5757
progress_callback_download: Callable[[float], None] = None,
5858
reduce_simulation: Literal["auto", True, False] = "auto",
59-
pay_type: PayType = PayType.AUTO,
59+
pay_type: Union[PayType, str] = PayType.AUTO,
6060
) -> ModeSolverData:
6161
"""Submits a :class:`.ModeSolver` to server, starts running, monitors progress, downloads,
6262
and loads results as a :class:`.ModeSolverData` object.
@@ -82,14 +82,13 @@ def run(
8282
reduce_simulation : Literal["auto", True, False] = "auto"
8383
Restrict simulation to mode solver region. If "auto", then simulation is automatically
8484
restricted if it contains custom mediums.
85-
pay_type: PayType = PayType.AUTO
85+
pay_type: Union[PayType, str] = PayType.AUTO
8686
Which method to pay the simulation.
8787
Returns
8888
-------
8989
:class:`.ModeSolverData`
9090
Mode solver data with the calculated results.
9191
"""
92-
9392
log_level = "DEBUG" if verbose else "INFO"
9493
if verbose:
9594
console = get_logging_console()
@@ -466,13 +465,16 @@ def upload(
466465

467466
def submit(
468467
self,
469-
pay_type: PayType = PayType.AUTO,
468+
pay_type: Union[PayType, str] = PayType.AUTO,
470469
):
471470
"""Start the execution of this task.
472471
473472
The mode solver must be uploaded to the server with the :meth:`ModeSolverTask.upload` method
474473
before this step.
475474
"""
475+
# convert right before sending to API
476+
pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type
477+
476478
http.post(
477479
f"{MODESOLVER_API}/{self.task_id}/{self.solver_id}/run",
478480
{

tidy3d/web/api/webapi.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import tempfile
66
import time
77
from datetime import datetime, timedelta
8-
from typing import Callable, Dict, List
8+
from typing import Callable, Dict, List, Union
99

1010
import pytz
1111
from requests import HTTPError
@@ -86,7 +86,7 @@ def run(
8686
simulation_type: str = "tidy3d",
8787
parent_tasks: list[str] = None,
8888
reduce_simulation: Literal["auto", True, False] = "auto",
89-
pay_type: PayType = PayType.AUTO,
89+
pay_type: Union[PayType, str] = PayType.AUTO,
9090
) -> SimulationDataType:
9191
"""
9292
Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads,
@@ -119,7 +119,7 @@ def run(
119119
worker group
120120
reduce_simulation : Literal["auto", True, False] = "auto"
121121
Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.
122-
pay_type: PayType = PayType.AUTO
122+
pay_type: Union[PayType, str] = PayType.AUTO
123123
Which method to pay the simulation.
124124
125125
Returns
@@ -376,7 +376,7 @@ def start(
376376
task_id: TaskId,
377377
solver_version: str = None,
378378
worker_group: str = None,
379-
pay_type: PayType = PayType.AUTO,
379+
pay_type: Union[PayType, str] = PayType.AUTO,
380380
) -> None:
381381
"""Start running the simulation associated with task.
382382
@@ -389,7 +389,7 @@ def start(
389389
target solver version.
390390
worker_group: str = None
391391
worker group
392-
pay_type: PayType = PayType.AUTO
392+
pay_type: Union[PayType, str] = PayType.AUTO
393393
Which method to pay the simulation
394394
Note
395395
----

tidy3d/web/core/task_core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pathlib
77
import tempfile
88
from datetime import datetime
9-
from typing import Callable, List, Optional, Tuple
9+
from typing import Callable, List, Optional, Tuple, Union
1010

1111
import pydantic.v1 as pd
1212
from botocore.exceptions import ClientError
@@ -413,7 +413,7 @@ def submit(
413413
self,
414414
solver_version: str = None,
415415
worker_group: str = None,
416-
pay_type: PayType = PayType.AUTO,
416+
pay_type: Union[PayType, str] = PayType.AUTO,
417417
):
418418
"""Kick off this task.
419419
@@ -427,9 +427,10 @@ def submit(
427427
target solver version.
428428
worker_group: str = None
429429
worker group
430-
pay_type: PayType = PayType.AUTO
430+
pay_type: Union[PayType, str] = PayType.AUTO
431431
Which method to pay the simulation.
432432
"""
433+
pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type
433434

434435
if solver_version:
435436
protocol_version = None

tidy3d/web/core/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,11 @@ class TaskType(str, Enum):
5959
class PayType(str, Enum):
6060
CREDITS = "FLEX_CREDIT"
6161
AUTO = "AUTO"
62+
63+
@classmethod
64+
def _missing_(cls, value: object) -> PayType:
65+
if isinstance(value, str):
66+
key = value.strip().replace(" ", "_").upper()
67+
if key in cls.__members__:
68+
return cls.__members__[key]
69+
return super()._missing_(value)

0 commit comments

Comments
 (0)