Skip to content

Commit 4e78971

Browse files
b8raoultpre-commit-ci[bot]dietervdb-meteo
authored
fix: support for netcdf missing values (#214)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Dieter Van den Bleeken <dieter.vandenbleeken@meteo.be>
1 parent 6925115 commit 4e78971

File tree

10 files changed

+225
-220
lines changed

10 files changed

+225
-220
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ tmp/
120120
temp/
121121
logs/
122122
_dev/
123-
outputs
124123
*tmp_data/
125124

126125
# Project specific

src/anemoi/inference/context.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,16 @@
1818
from typing import List
1919
from typing import Optional
2020

21-
from anemoi.inference.input import Input
22-
from anemoi.inference.output import Output
2321
from anemoi.inference.processor import Processor
2422
from anemoi.inference.types import IntArray
2523

2624
if TYPE_CHECKING:
25+
from anemoi.inference.input import Input
26+
from anemoi.inference.output import Output
27+
2728
from .checkpoint import Checkpoint
2829
from .forcings import Forcings
2930

30-
3131
LOG = logging.getLogger(__name__)
3232

3333

@@ -64,7 +64,7 @@ def checkpoint(self) -> "Checkpoint":
6464
# expected to provide the forcings directly as input to the runner.
6565
##################################################################
6666

67-
def create_input(self) -> Input:
67+
def create_input(self) -> "Input":
6868
"""Creates an input object for the inference.
6969
7070
Returns
@@ -74,7 +74,7 @@ def create_input(self) -> Input:
7474
"""
7575
raise NotImplementedError()
7676

77-
def create_output(self) -> Output:
77+
def create_output(self) -> "Output":
7878
"""Creates an output object for the inference.
7979
8080
Returns

src/anemoi/inference/output.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,24 +180,98 @@ class ForwardOutput(Output):
180180
"""
181181

182182
def __init__(
183-
self, context: "Context", output_frequency: Optional[int] = None, write_initial_state: Optional[bool] = None
183+
self,
184+
context: "Context",
185+
output: dict,
186+
output_frequency: Optional[int] = None,
187+
write_initial_state: Optional[bool] = None,
184188
):
185189
"""Initialize the ForwardOutput object.
186190
187191
Parameters
188192
----------
189193
context : Context
190194
The context in which the output operates.
195+
output : dict
196+
The output configuration dictionary.
191197
output_frequency : Optional[int], optional
192198
The frequency at which to output states, by default None.
193199
write_initial_state : Optional[bool], optional
194200
Whether to write the initial state, by default None.
195201
"""
202+
203+
from anemoi.inference.outputs import create_output
204+
196205
super().__init__(context, output_frequency=None, write_initial_state=write_initial_state)
206+
207+
self.output = None if output is None else create_output(context, output)
208+
197209
if self.context.output_frequency is not None:
198210
LOG.warning("output_frequency is ignored for '%s'", self.__class__.__name__)
199211

200212
@cached_property
201213
def output_frequency(self) -> Optional[datetime.timedelta]:
202214
"""Get the output frequency."""
203215
return None
216+
217+
def modify_state(self, state: State) -> State:
218+
"""Modify the state before writing.
219+
220+
Parameters
221+
----------
222+
state : State
223+
The state to modify.
224+
225+
Returns
226+
-------
227+
State
228+
The modified state.
229+
"""
230+
return state
231+
232+
def open(self, state) -> None:
233+
"""Open the output for writing.
234+
Parameters
235+
----------
236+
state : State
237+
The initial state.
238+
"""
239+
self.output.open(self.modify_state(state))
240+
241+
def close(self) -> None:
242+
"""Close the output."""
243+
244+
self.output.close()
245+
246+
def write_initial_step(self, state: State) -> None:
247+
"""Write the initial step of the state.
248+
249+
Parameters
250+
----------
251+
state : State
252+
The state dictionary.
253+
"""
254+
state.setdefault("step", datetime.timedelta(0))
255+
256+
self.output.write_initial_state(self.modify_state(state))
257+
258+
def write_step(self, state: State) -> None:
259+
"""Write a step of the state.
260+
261+
Parameters
262+
----------
263+
state : State
264+
The state to write.
265+
"""
266+
self.output.write_state(self.modify_state(state))
267+
268+
def print_summary(self, depth: int = 0) -> None:
269+
"""Print a summary of the output.
270+
271+
Parameters
272+
----------
273+
depth : int, optional
274+
The depth of the summary, by default 0.
275+
"""
276+
super().print_summary(depth)
277+
self.output.print_summary(depth + 1)

src/anemoi/inference/outputs/apply_mask.py

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,15 @@
1212

1313
from anemoi.inference.config import Configuration
1414
from anemoi.inference.context import Context
15-
from anemoi.inference.types import State
1615

17-
from ..output import ForwardOutput
18-
from . import create_output
1916
from . import output_registry
17+
from .masked import MaskedOutput
2018

2119
LOG = logging.getLogger(__name__)
2220

2321

2422
@output_registry.register("apply_mask")
25-
class ApplyMaskOutput(ForwardOutput):
23+
class ApplyMaskOutput(MaskedOutput):
2624
"""Apply mask output class.
2725
2826
Parameters
@@ -48,75 +46,10 @@ def __init__(
4846
output_frequency: Optional[int] = None,
4947
write_initial_state: Optional[bool] = None,
5048
) -> None:
51-
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
52-
self.mask = self.checkpoint.load_supporting_array(mask)
53-
self.output = create_output(context, output)
54-
55-
def __repr__(self) -> str:
56-
"""Return a string representation of the ApplyMaskOutput object."""
57-
return f"ApplyMaskOutput({self.mask}, {self.output})"
58-
59-
def write_initial_step(self, state: State) -> None:
60-
"""Write the initial step of the state.
61-
62-
Parameters
63-
----------
64-
state : State
65-
The state dictionary.
66-
"""
67-
# Note: we foreward to 'state', so we write-up options again
68-
self.output.write_initial_state(self._apply_mask(state))
69-
70-
def write_step(self, state: State) -> None:
71-
"""Write a step of the state.
72-
73-
Parameters
74-
----------
75-
state : State
76-
The state dictionary.
77-
"""
78-
# Note: we foreward to 'state', so we write-up options again
79-
self.output.write_state(self._apply_mask(state))
80-
81-
def _apply_mask(self, state: State) -> State:
82-
"""Apply the mask to the state.
83-
84-
Parameters
85-
----------
86-
state : State
87-
The state dictionary.
88-
89-
Returns
90-
-------
91-
State
92-
The masked state dictionary.
93-
"""
94-
state = state.copy()
95-
state["fields"] = state["fields"].copy()
96-
state["latitudes"] = state["latitudes"][self.mask]
97-
state["longitudes"] = state["longitudes"][self.mask]
98-
99-
for field in state["fields"]:
100-
data = state["fields"][field]
101-
if data.ndim == 1:
102-
data = data[self.mask]
103-
else:
104-
data = data[..., self.mask]
105-
state["fields"][field] = data
106-
107-
return state
108-
109-
def close(self) -> None:
110-
"""Close the output."""
111-
self.output.close()
112-
113-
def print_summary(self, depth: int = 0) -> None:
114-
"""Print the summary of the output.
115-
116-
Parameters
117-
----------
118-
depth : int, optional
119-
The depth of the summary, by default 0.
120-
"""
121-
super().print_summary(depth)
122-
self.output.print_summary(depth + 1)
49+
super().__init__(
50+
context,
51+
mask=self.checkpoint.load_supporting_array(mask),
52+
output=output,
53+
output_frequency=output_frequency,
54+
write_initial_state=write_initial_state,
55+
)

src/anemoi/inference/outputs/extract_lam.py

Lines changed: 14 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,15 @@
1414

1515
from anemoi.inference.config import Configuration
1616
from anemoi.inference.context import Context
17-
from anemoi.inference.types import State
1817

19-
from ..output import ForwardOutput
20-
from . import create_output
2118
from . import output_registry
19+
from .masked import MaskedOutput
2220

2321
LOG = logging.getLogger(__name__)
2422

2523

2624
@output_registry.register("extract_lam")
27-
class ExtractLamOutput(ForwardOutput):
25+
class ExtractLamOutput(MaskedOutput):
2826
"""Extract LAM output class.
2927
3028
Parameters
@@ -50,91 +48,27 @@ def __init__(
5048
output_frequency: Optional[int] = None,
5149
write_initial_state: Optional[bool] = None,
5250
) -> None:
53-
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
5451

55-
if "cutout_mask" in self.checkpoint.supporting_arrays:
52+
if "cutout_mask" in context.checkpoint.supporting_arrays:
5653
# Backwards compatibility
57-
mask = self.checkpoint.load_supporting_array("cutout_mask")
54+
mask = context.checkpoint.load_supporting_array("cutout_mask")
5855
points = slice(None, -np.sum(mask))
5956
else:
6057
if "lam_0" not in lam:
6158
raise NotImplementedError("Only lam_0 is supported")
6259

63-
if "lam_1/cutout_mask" in self.checkpoint.supporting_arrays:
60+
if "lam_1/cutout_mask" in context.checkpoint.supporting_arrays:
6461
raise NotImplementedError("Only lam_0 is supported")
6562

66-
mask = self.checkpoint.load_supporting_array(f"{lam}/cutout_mask")
63+
mask = context.checkpoint.load_supporting_array(f"{lam}/cutout_mask")
64+
6765
assert len(mask) == np.sum(mask)
6866
points = slice(None, np.sum(mask))
6967

70-
self.points = points
71-
self.output = create_output(context, output)
72-
73-
def __repr__(self) -> str:
74-
"""Return a string representation of the ExtractLamOutput object."""
75-
return f"ExtractLamOutput({self.points}, {self.output})"
76-
77-
def write_initial_state(self, state: State) -> None:
78-
"""Write the initial step of the state.
79-
80-
Parameters
81-
----------
82-
state : State
83-
The state dictionary.
84-
"""
85-
# Note: we foreward to 'state', so we write-up options again
86-
self.output.write_initial_state(self._apply_mask(state))
87-
88-
def write_step(self, state: State) -> None:
89-
"""Write a step of the state.
90-
91-
Parameters
92-
----------
93-
state : State
94-
The state dictionary.
95-
"""
96-
# Note: we foreward to 'state', so we write-up options again
97-
self.output.write_state(self._apply_mask(state))
98-
99-
def _apply_mask(self, state: State) -> State:
100-
"""Apply the mask to the state.
101-
102-
Parameters
103-
----------
104-
state : State
105-
The state dictionary.
106-
107-
Returns
108-
-------
109-
State
110-
The masked state dictionary.
111-
"""
112-
state = state.copy()
113-
state["fields"] = state["fields"].copy()
114-
state["latitudes"] = state["latitudes"][self.points]
115-
state["longitudes"] = state["longitudes"][self.points]
116-
117-
for field in state["fields"]:
118-
data = state["fields"][field]
119-
if data.ndim == 1:
120-
data = data[self.points]
121-
else:
122-
data = data[..., self.points]
123-
state["fields"][field] = data
124-
125-
return state
126-
127-
def close(self) -> None:
128-
"""Close the output."""
129-
self.output.close()
130-
131-
def print_summary(self, depth: int = 0) -> None:
132-
"""Print the summary of the output.
133-
134-
Parameters
135-
----------
136-
depth : int, optional
137-
The depth of the summary, by default 0.
138-
"""
139-
super().print_summary(depth)
140-
self.output.print_summary(depth + 1)
68+
super().__init__(
69+
context,
70+
mask=points,
71+
output=output,
72+
output_frequency=output_frequency,
73+
write_initial_state=write_initial_state,
74+
)

0 commit comments

Comments
 (0)