Commit 6765316

dietervdb-meteo, gmertes and frazane authored
feat: inference with external graph. (#216)
## Description

We add functionality so that in inference the model can use another graph than the one it was trained with. In this first implementation the graph has to be provided as a file on disk.

This PR adds a new runner `runner: external_graph` that is an extension of the default runner. The code is based on a similar feature in bris-inference: https://github.com/metno/bris-inference/blob/main/bris/checkpoint.py#L185

The runner can be selected and set in the config as follows:

```yaml
runner:
  external_graph:
    graph: path/to/graph.pt
```

For further options for the runner please consult the documentation.

## Type of Change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [x] Documentation update

## Issue Number

Closes #215.

## Code Compatibility

- [x] I have performed a self-review of my code

### Code Performance and Testing

- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I ran the [complete Pytest test](https://anemoi.readthedocs.io/projects/training/en/latest/dev/testing.html) suite locally, and they pass

### Dependencies

- [x] I have ensured that the code is still pip-installable after the changes and runs
- [ ] I have tested that new dependencies themselves are pip-installable.

### Documentation

- [ ] My code follows the style guidelines of this project
- [x] I have updated the documentation and docstrings to reflect the changes
- [x] I have added comments to my code, particularly in hard-to-understand areas

## Additional Notes

----

📚 Documentation preview 📚: https://anemoi-inference--216.org.readthedocs.build/en/216/

---------

Co-authored-by: Gert Mertes <gert.mertes@ecmwf.int>
Co-authored-by: Francesco Zanetta <zanetta.francesco@gmail.com>
1 parent bab2dbe commit 6765316

10 files changed

+276
-3
lines changed

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ You may also have to install pandoc on MacOS:
    :caption: Recipe Examples

    usage/getting-started
+   usage/external-graph

 .. toctree::
    :maxdepth: 1

docs/usage/external-graph.rst

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+.. _usage-external-graph:
+
+###################################
+ Inference using an external graph
+###################################
+
+Anemoi is a framework for building and running machine learning models
+based on graph neural networks (GNNs). One of the key features of such
+GNNs is that they can operate on arbitrary graphs. In particular, it
+means one can train the model on one graph but use it in inference on
+another graph. This way one can transfer the model to a different domain
+or dataset without any fine-tuning, or even change the scope of a
+model, for example by using a model trained on a stretched grid as a
+limited area model (LAM) with boundary forcings in inference.
+
+We should caution that such a transfer of the model from one graph to
+another is not guaranteed to lead to good results. Still, it is a
+powerful tool to explore the generalizability of the model or to test
+performance before starting fine-tuning through transfer learning.
+
+The ability to do inference with an alternative graph, or more precisely
+one 'external' to the checkpoint created in training, is supported by
+anemoi-inference through the ``external_graph`` runner.
+
+This runner, and the graph it will use, can be specified in the config
+file as follows:
+
+.. literalinclude:: yaml/external-graph1.yaml
+   :language: yaml
+
+In case one wants to run a model trained on a global dataset on a graph
+supported only on a limited area, one needs to specify the
+``output_mask`` to be used. This mask selects the region on which the
+model will forecast and triggers boundary forcings to be applied when
+forecasting autoregressively towards later lead times. As in training,
+the output mask in inference originates from an attribute of the
+output nodes of the graph. It can be specified in the config file as
+follows:
+
+.. literalinclude:: yaml/external-graph2.yaml
+   :language: yaml
+
+For LAM models the limited area among the input nodes of a larger
+dataset is often specified by the ``indices_connected_nodes`` attribute
+of the input nodes. Anemoi-inference will automatically update the
+dataloader to load only data in the limited area in case the external
+graph contains this attribute and was built using the same dataset as
+the one in the checkpoint.
+
+In case one wants to work with a graph that was built on another dataset
+than that used in training, one should specify this in the config file as
+well:
+
+.. literalinclude:: yaml/external-graph3.yaml
+   :language: yaml
+
+It should be emphasized that by using this runner the model will be
+rebuilt and for this reason will differ from the model stored in the
+checkpoint. To avoid unexpected results, there is a default check that
+ensures the model used in inference has the same weights, biases and
+normalizer values as that stored in the checkpoint. In case of a more
+adventurous use-case this check can be disabled through the config as:
+
+.. literalinclude:: yaml/external-graph4.yaml
+   :language: yaml

docs/usage/yaml/external-graph1.yaml

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+runner:
+  external_graph:
+    graph: path/to/graph.pt

docs/usage/yaml/external-graph2.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+runner:
+  external_graph:
+    graph: path/to/graph.pt
+    output_mask:
+      nodes_name: data # name of the output nodes of the graph
+      attribute_name: cutout_mask # mask specifying the limited area among the output nodes
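
As a conceptual illustration only, the effect of an output mask can be pictured as follows: inside the masked region the model's prediction is kept, and outside it the values are taken from boundary forcings. This is a toy sketch under that assumption, not the anemoi-inference implementation; the function name `apply_output_mask` is hypothetical and plain lists stand in for arrays.

```python
from typing import List, Sequence


def apply_output_mask(
    prediction: Sequence[float],
    boundary_forcing: Sequence[float],
    mask: Sequence[bool],
) -> List[float]:
    """Keep the prediction where mask is True, use the forcing elsewhere.

    Hypothetical helper for illustration; all three inputs hold one
    value per output node.
    """
    if not (len(prediction) == len(boundary_forcing) == len(mask)):
        raise ValueError("prediction, boundary_forcing and mask must have equal length")
    return [p if inside else b for p, b, inside in zip(prediction, boundary_forcing, mask)]


# Nodes 0 and 2 lie inside the limited area, node 1 on the boundary.
merged = apply_output_mask([1.0, 2.0, 3.0], [10.0, 20.0, 30.0], [True, False, True])
```

In the runner itself the mask is read from the graph attribute named by `attribute_name` and stored as the `output_mask` supporting array.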

docs/usage/yaml/external-graph3.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+runner:
+  external_graph:
+    graph: path/to/graph.pt
+    graph_dataset: path/to/graph_dataset.zarr
+    # the above can be an anemoi-datasets.open_dataset argument as well,
+    # rather than simply a path
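
When `graph_dataset` is set, the runner added in this commit also recomputes the number of grid points: it uses the length of the `grid_indices` supporting array if the dataset provides one, and otherwise the last dimension of the dataset shape. A minimal plain-Python sketch of that rule follows; the helper name `resolve_number_of_grid_points` is hypothetical and exists only for illustration.

```python
from typing import Any, Mapping, Sequence


def resolve_number_of_grid_points(
    dataset_shape: Sequence[int],
    supporting_arrays: Mapping[str, Any],
) -> int:
    """Mirror the fallback used in ExternalGraphRunner.__init__.

    Hypothetical helper: prefer the length of 'grid_indices' when it is
    present, else fall back to the last dimension of the dataset shape.
    """
    if "grid_indices" in supporting_arrays:
        return len(supporting_arrays["grid_indices"])
    return dataset_shape[-1]


# Example shapes are made up for illustration.
full_grid = resolve_number_of_grid_points((10, 5, 1, 40320), {})
masked_grid = resolve_number_of_grid_points((10, 5, 1, 40320), {"grid_indices": [0, 3, 7]})
```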

docs/usage/yaml/external-graph4.yaml

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+runner:
+  external_graph:
+    graph: path/to/graph.pt
+    check_state_dict: False

src/anemoi/inference/config/run.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class RunConfiguration(Configuration):
     checkpoint: Union[str, Dict[Literal["huggingface"], Union[Dict[str, Any], str]]]
     """A path to an Anemoi checkpoint file."""

-    runner: str = "default"
+    runner: Union[str, Dict[str, Any]] = "default"
     """The runner to use."""

     date: Union[str, int, datetime.datetime, None] = None

src/anemoi/inference/runner.py

Lines changed: 1 addition & 1 deletion
@@ -401,7 +401,7 @@ def prepare_input_tensor(self, input_state: State, dtype: DTypeLike = np.float32
             shape=(
                 self.checkpoint.multi_step_input,
                 self.checkpoint.number_of_input_features,
-                self.checkpoint.number_of_grid_points,
+                input_state["latitudes"].size,
             ),
             fill_value=np.nan,
             dtype=dtype,

src/anemoi/inference/runners/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -30,4 +30,4 @@ def create_runner(config: Configuration, **kwargs: Any) -> Any:
     Any
         The created runner instance.
     """
-    return runner_registry.create(config.runner, config, **kwargs)
+    return runner_registry.from_config(config.runner, config, **kwargs)
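
With the widened `runner` type, the config value can be either a bare name (`"default"`) or a single-key mapping whose value holds the runner's options. The sketch below shows one way such a value could be split into a name and keyword arguments. It illustrates the general pattern only and is not the code behind `runner_registry.from_config`; the helper name `split_runner_spec` is hypothetical.

```python
from typing import Any, Dict, Tuple, Union


def split_runner_spec(spec: Union[str, Dict[str, Any]]) -> Tuple[str, Dict[str, Any]]:
    """Split a runner spec into (name, options).

    Hypothetical helper for illustration: a string is a runner name with
    no options, a single-key dict maps the name to its options.
    """
    if isinstance(spec, str):
        return spec, {}
    if isinstance(spec, dict) and len(spec) == 1:
        name, options = next(iter(spec.items()))
        return name, dict(options or {})
    raise ValueError(f"Expected a runner name or a single-key mapping, got {spec!r}")


name, options = split_runner_spec({"external_graph": {"graph": "path/to/graph.pt"}})
```

Here `name` is `"external_graph"` and `options` carries the `graph` path, which matches the keyword arguments that `ExternalGraphRunner.__init__` accepts in this commit.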
Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+import logging
+import os
+from copy import deepcopy
+from functools import cached_property
+from typing import Any
+
+import torch
+from anemoi.datasets import open_dataset
+
+from ..runners.default import DefaultRunner
+from . import runner_registry
+
+LOG = logging.getLogger(__name__)
+
+# Possibly move the function(s) below to anemoi-models or anemoi-utils since it could be used in transfer learning.
+
+
+def contains_any(key, specifications):
+    contained = False
+    for specification in specifications:
+        if specification in key:
+            contained = True
+            break
+    return contained
+
+
+def update_state_dict(
+    model, external_state_dict, keywords="", ignore_mismatched_layers=False, ignore_additional_layers=False
+):
+    """Update the model's state_dict with entries from an external state_dict. Only entries whose keys contain the specified keywords are considered."""
+
+    LOG.info("Updating model state dictionary.")
+
+    if isinstance(keywords, str):
+        keywords = [keywords]
+
+    # select relevant part of external_state_dict
+    reduced_state_dict = {k: v for k, v in external_state_dict.items() if contains_any(k, keywords)}
+    model_state_dict = model.state_dict()
+
+    # check layers and their shapes
+    for key in list(reduced_state_dict):
+        if key not in model_state_dict:
+            if ignore_additional_layers:
+                LOG.info("Skipping injection of %s, which is not in the model.", key)
+                del reduced_state_dict[key]
+            else:
+                raise AssertionError(f"Layer {key} not in model. Consider setting 'ignore_additional_layers = True'.")
+        elif reduced_state_dict[key].shape != model_state_dict[key].shape:
+            if ignore_mismatched_layers:
+                LOG.info("Skipping injection of %s due to shape mismatch.", key)
+                LOG.info("Model shape: %s", model_state_dict[key].shape)
+                LOG.info("Provided shape: %s", reduced_state_dict[key].shape)
+                del reduced_state_dict[key]
+            else:
+                raise AssertionError(
+                    f"Mismatch in shape of {key}. Consider setting 'ignore_mismatched_layers = True'."
+                )
+
+    # update
+    model.load_state_dict(reduced_state_dict, strict=False)
+    return model
+
+
+@runner_registry.register("external_graph")
+class ExternalGraphRunner(DefaultRunner):
+    """Runner where the graph saved in the checkpoint is replaced by an externally provided one.
+    Currently only supported as an extension of the default runner.
+    """
+
+    def __init__(
+        self,
+        config: dict,
+        graph: str,
+        output_mask: dict | None = None,
+        graph_dataset: Any | None = None,
+        check_state_dict: bool | None = True,
+    ) -> None:
+        """Initialize the ExternalGraphRunner.
+
+        Parameters
+        ----------
+        config : Configuration
+            Configuration for the runner.
+        graph : str
+            Path to the external graph.
+        output_mask : dict | None
+            Dictionary specifying the output mask.
+        graph_dataset : Any | None
+            Argument to open_dataset of anemoi-datasets that recreates the dataset used to build the data nodes of the graph.
+        check_state_dict : bool | None
+            Boolean specifying whether to check that the state dict is reconstructed as expected.
+        """
+        super().__init__(config)
+        self.check_state_dict = check_state_dict
+        self.graph_path = graph
+
+        # If the graph was built on another dataset, we need to adapt the dataloader
+        if graph_dataset is not None:
+            graph_ds = open_dataset(graph_dataset)
+            LOG.info(
+                "The external graph was built using a different anemoi-dataset than that in the checkpoint. "
+                "Patching metadata to ensure correct data loading."
+            )
+            self.checkpoint._metadata.patch(
+                {
+                    "config": {"dataloader": {"dataset": graph_dataset}},
+                    "dataset": {"shape": graph_ds.shape},
+                }
+            )
+
+            # had to use private attributes because cached properties cause problems
+            self.checkpoint._metadata._supporting_arrays = graph_ds.supporting_arrays()
+            if "grid_indices" in self.checkpoint._metadata._supporting_arrays:
+                num_grid_points = len(self.checkpoint._metadata._supporting_arrays["grid_indices"])
+            else:
+                num_grid_points = graph_ds.shape[-1]
+            self.checkpoint._metadata.number_of_grid_points = num_grid_points
+
+        # Check if the external graph has the 'indices_connected_nodes' attribute
+        # If so adapt dataloader and add supporting array
+        data = self.checkpoint._metadata._config.graph.data
+        assert data in self.graph.node_types, f"Node type {data} not found in external graph."
+        if "indices_connected_nodes" in self.graph[data]:
+            LOG.info(
+                "The external graph has the 'indices_connected_nodes' attribute. "
+                "Patching metadata with MaskedGrid 'grid_indices' to ensure correct data loading."
+            )
+            self.checkpoint._metadata.patch(
+                {
+                    "config": {
+                        "dataloader": {
+                            "grid_indices": {
+                                "_target_": "anemoi.training.data.grid_indices.MaskedGrid",
+                                "nodes_name": data,
+                                "node_attribute_name": "indices_connected_nodes",
+                            }
+                        }
+                    }
+                }
+            )
+            LOG.info("Moving 'indices_connected_nodes' from external graph to supporting arrays as 'grid_indices'.")
+            indices_connected_nodes = self.graph[data]["indices_connected_nodes"].numpy()
+            self.checkpoint._supporting_arrays["grid_indices"] = indices_connected_nodes.squeeze()
+
+        if output_mask:
+            nodes = output_mask["nodes_name"]
+            attribute = output_mask["attribute_name"]
+            self.checkpoint._supporting_arrays["output_mask"] = self.graph[nodes][attribute].numpy().squeeze()
+            LOG.info(
+                "Moving attribute '%s' of nodes '%s' from external graph to supporting arrays as 'output_mask'.",
+                attribute,
+                nodes,
+            )
+
+    @cached_property
+    def graph(self):
+        graph_path = self.graph_path
+        assert os.path.isfile(
+            graph_path
+        ), f"No graph found at {graph_path}. An external graph needs to be specified in the config file for this runner."
+        LOG.info("Loading external graph from path %s.", graph_path)
+        return torch.load(graph_path, map_location="cpu", weights_only=False)
+
+    @cached_property
+    def model(self):
+        # load the model from the checkpoint
+        device = self.device
+        self.device = "cpu"
+        model_instance = super().model
+        state_dict_ckpt = deepcopy(model_instance.state_dict())
+
+        # rebuild the model with the new graph
+        model_instance.graph_data = self.graph
+        model_instance.config = self.checkpoint._metadata._config
+        model_instance._build_model()
+
+        # reinstate the weights, biases and normalizer from the checkpoint
+        # reinstating the normalizer is necessary for checkpoints that were created
+        # using transfer learning, where the statistics as stored in the checkpoint
+        # do not match the statistics used to build the normalizer in the checkpoint.
+        model_instance = update_state_dict(
+            model_instance, state_dict_ckpt, keywords=["bias", "weight", "processors.normalizer"]
+        )
+
+        LOG.info("Successfully built model with external graph and reassigned model weights!")
+        self.device = device
+        return model_instance.to(self.device)
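
The `keywords` argument of `update_state_dict` selects entries by substring match on the state-dict key. The torch-free sketch below reproduces only that selection step on a plain dictionary, to make the matching behaviour easy to check. The helper name `select_by_keywords` and the example keys are hypothetical; the real function additionally validates shapes and calls `load_state_dict`.

```python
from typing import Any, Dict, Iterable, Union


def select_by_keywords(
    state_dict: Dict[str, Any],
    keywords: Union[str, Iterable[str]] = "",
) -> Dict[str, Any]:
    """Return the entries whose key contains at least one keyword.

    Hypothetical stand-in for the filtering step of update_state_dict.
    """
    if isinstance(keywords, str):
        keywords = [keywords]
    keywords = list(keywords)
    return {k: v for k, v in state_dict.items() if any(kw in k for kw in keywords)}


# Made-up keys; integers stand in for tensors.
state = {
    "encoder.linear.weight": 1,
    "encoder.linear.bias": 2,
    "processors.normalizer.offset": 3,
    "trainable.data": 4,
}
selected = select_by_keywords(state, ["bias", "weight", "processors.normalizer"])
everything = select_by_keywords(state)
```

Because the empty string is a substring of every key, the default `keywords=""` selects all entries, whereas the keyword list used in `ExternalGraphRunner.model` leaves out keys such as the made-up `trainable.data` above.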
