Skip to content

Commit 5d2dd47

Browse files
authored
feat(retrieve): Use input grib as target (#88)
* feat: use input grib as retrieve target * feat: retrieve default output to stdout * fix: allow for grib input to be a nested dict * fix: create input object * fix: handle stdout correctly * refactor
1 parent 1a5b18e commit 5d2dd47

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/anemoi/inference/commands/retrieve.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010

1111
import json
12+
import sys
1213

1314
from earthkit.data.utils.dates import to_datetime
1415

1516
from ..config import load_config
17+
from ..inputs.grib import GribInput
1618
from ..inputs.mars import postproc
1719
from ..runners.default import DefaultRunner
1820
from . import Command
@@ -28,7 +30,7 @@ def add_arguments(self, command_parser):
2830
command_parser.add_argument("config", type=str, help="Path to checkpoint")
2931
command_parser.add_argument("--defaults", action="append", help="Sources of default values.")
3032
command_parser.add_argument("--date", type=str, help="Date")
31-
command_parser.add_argument("--output", type=str, help="Output file")
33+
command_parser.add_argument("--output", type=str, default=None, help="Output file")
3234
command_parser.add_argument("--staging-dates", type=str, help="Path to a file with staging dates")
3335
command_parser.add_argument("--extra", action="append", help="Additional request values. Can be repeated")
3436
command_parser.add_argument("--retrieve-fields-type", type=str, help="Type of fields to retrieve")
@@ -60,6 +62,11 @@ def run(self, args):
6062

6163
extra = postproc(grid, area)
6264

65+
# so that the user does not need to pass --extra target=path when the input file is already in the config
66+
input = runner.create_input()
67+
if isinstance(input, GribInput):
68+
extra["target"] = input.path
69+
6370
for r in args.extra or []:
6471
k, v = r.split("=")
6572
extra[k] = v
@@ -83,8 +90,12 @@ def run(self, args):
8390
r.update(extra)
8491
requests.append(r)
8592

86-
with open(args.output, "w") as f:
87-
json.dump(requests, f, indent=4)
93+
if args.output and args.output != "-":
94+
with open(args.output, "w") as f:
95+
json.dump(requests, f, indent=4)
96+
return
97+
98+
json.dump(requests, sys.stdout, indent=4)
8899

89100

90101
command = RetrieveCmd

0 commit comments

Comments
 (0)