Skip to content

Commit 0767a87

Browse files
Merge pull request #361 from ecmwf/feature/fix-netcdf-xy-coordinate-getter
Fix netcdf xy coordinate access
2 parents cf038f5 + faf00e5 commit 0767a87

File tree

3 files changed

+148
-131
lines changed

3 files changed

+148
-131
lines changed

earthkit/data/readers/netcdf/dataset.py

Lines changed: 70 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ def __getattr__(self, name):
4343
def bbox(self, variable):
4444
data_array = self[variable]
4545

46-
coords = self._get_xy_coords(data_array)
47-
key = ("bbox", tuple(coords))
46+
keys, coords = self._get_xy_coords(data_array)
47+
key = ("bbox", tuple(keys), tuple(coords))
4848
if key in self._cache:
4949
return self._cache[key]
5050

51-
lons = self._get_xy(data_array, "x", flatten=False)
52-
lats = self._get_xy(data_array, "y", flatten=False)
51+
# lons = self._get_xy(data_array, "x", flatten=False)
52+
# lats = self._get_xy(data_array, "y", flatten=False)
53+
54+
lats, lons = self._get_latlon(data_array, flatten=False)
5355

5456
north = np.amax(lats)
5557
west = np.amin(lons)
@@ -59,143 +61,86 @@ def bbox(self, variable):
5961
self._cache[key] = (north, west, south, east)
6062
return self._cache[key]
6163

62-
# dims = data_array.dims
63-
64-
# lat = dims[-2]
65-
# lon = dims[-1]
66-
67-
# key = ("bbox", lat, lon)
68-
# if key in self._cache:
69-
# return self._cache[key]
70-
71-
# lats, lons = self.grid_points(variable)
72-
# north = np.amax(lats)
73-
# west = np.amin(lons)
74-
# south = np.amin(lats)
75-
# east = np.amax(lons)
76-
77-
# self._cache[key] = (north, west, south, east)
78-
# return self._cache[key]
79-
80-
# if (lat, lon) not in self._bbox:
81-
# dims = data_array.dims
82-
83-
# latitude = data_array[lat]
84-
# longitude = data_array[lon]
85-
86-
# self._bbox[(lat, lon)] = (
87-
# np.amax(latitude.data),
88-
# np.amin(longitude.data),
89-
# np.amin(latitude.data),
90-
# np.amax(longitude.data),
91-
# )
92-
93-
# return self._bbox[(lat, lon)]
94-
95-
# def grid_points(self, variable):
96-
# data_array = self[variable]
97-
# dims = data_array.dims
98-
99-
# lat = dims[-2]
100-
# lon = dims[-1]
101-
102-
# key = ("grid_points", lat, lon)
103-
# if key in self._cache:
104-
# return self._cache[key]
105-
106-
# if "latitude" in self._ds and "longitude" in self._ds:
107-
# latitude = self._ds["latitude"]
108-
# longitude = self._ds["longitude"]
109-
110-
# if latitude.dims == (lat, lon) and longitude.dims == (lat, lon):
111-
# latitude = latitude.data
112-
# longitude = longitude.data
113-
# return latitude.flatten(), longitude.flatten()
114-
115-
# latitude = data_array[lat]
116-
# longitude = data_array[lon]
117-
118-
# lat, lon = np.meshgrid(latitude.data, longitude.data)
119-
120-
# self._cache[key] = lat.flatten(), lon.flatten()
121-
# return self._cache[key]
122-
123-
# def grid_points_xy(self, variable):
124-
# data_array = self[variable]
125-
# dims = data_array.dims
126-
127-
# lat = dims[-2]
128-
# lon = dims[-1]
129-
130-
# latitude = data_array[lat].data
131-
# longitude = data_array[lon].data
132-
133-
# print(latitude, longitude)
134-
135-
# lat, lon = np.meshgrid(latitude, longitude)
136-
137-
# return lat.flatten(), lon.flatten()
138-
# # return self._cache[key]
139-
140-
def _get_xy(self, data_array, axis, flatten=False, dtype=None):
141-
if axis not in ("x", "y"):
142-
raise ValueError(f"Invalid axis={axis}")
143-
144-
coords = self._get_xy_coords(data_array)
145-
key = ("grid_points", tuple(coords))
146-
if key in self._cache:
147-
points = self._cache[key]
148-
else:
149-
points = dict()
150-
keys = [x[0] for x in coords]
151-
coords = tuple([x[1] for x in coords])
152-
153-
if "latitude" in self._ds and "longitude" in self._ds:
154-
latitude = self._ds["latitude"]
155-
longitude = self._ds["longitude"]
156-
157-
if latitude.dims == coords and longitude.dims == coords:
158-
latitude = latitude.data
159-
longitude = longitude.data
160-
points["x"] = longitude
161-
points["y"] = latitude
162-
if not points:
163-
v0, v1 = data_array.coords[coords[0]], data_array.coords[coords[1]]
164-
points[keys[1]], points[keys[0]] = np.meshgrid(v1, v0)
165-
self._cache[key] = points
166-
167-
if flatten:
168-
points[axis] = points[axis].reshape(-1)
169-
if dtype is not None:
170-
return points[axis].astype(dtype)
171-
else:
172-
return points[axis]
173-
17464
def _get_xy_coords(self, data_array):
175-
c = []
176-
17765
if (
17866
len(data_array.dims) >= 2
17967
and data_array.dims[-1] in GEOGRAPHIC_COORDS["x"]
18068
and data_array.dims[-2] in GEOGRAPHIC_COORDS["y"]
18169
):
182-
return [("y", data_array.dims[-2]), ("x", data_array.dims[-1])]
70+
return ("y", "x"), (data_array.dims[-2], data_array.dims[-1])
18371

72+
keys = []
73+
coords = []
18474
axes = ("x", "y")
18575
for dim in data_array.dims:
18676
for ax in axes:
18777
candidates = GEOGRAPHIC_COORDS.get(ax, [])
18878
if dim in candidates:
189-
c.append((ax, dim))
79+
keys.append(ax)
80+
coords.append(dim)
19081
else:
19182
ax = data_array.coords[dim].attrs.get("axis", "").lower()
19283
if ax in axes:
193-
c.append([ax, dim])
194-
if len(c) == 2:
195-
return c
84+
keys.append(ax)
85+
coords.append(dim)
86+
if len(keys) == 2:
87+
return tuple(keys), tuple(coords)
19688

19789
for ax in axes:
198-
if ax not in [x[0] for x in c]:
90+
if ax not in keys:
19991
raise ValueError(f"No coordinate found with axis '{ax}'")
20092

201-
return c
93+
return keys, coords
94+
95+
def _get_xy(self, data_array, flatten=False, dtype=None):
96+
keys, coords = self._get_xy_coords(data_array)
97+
key = ("grid_points", tuple(keys), tuple(coords))
98+
99+
if key in self._cache:
100+
points = self._cache[key]
101+
else:
102+
points = dict()
103+
v0, v1 = data_array.coords[coords[0]], data_array.coords[coords[1]]
104+
points[keys[1]], points[keys[0]] = np.meshgrid(v1, v0)
105+
self._cache[key] = points
106+
107+
if flatten:
108+
points["x"] = points["x"].reshape(-1)
109+
points["y"] = points["y"].reshape(-1)
110+
111+
if dtype is not None:
112+
return points["x"].astype(dtype), points["y"].astype(dtype)
113+
else:
114+
return points["x"], points["y"]
115+
116+
def _get_latlon(self, data_array, flatten=False, dtype=None):
117+
keys, coords = self._get_xy_coords(data_array)
118+
119+
points = dict()
120+
if "latitude" in self._ds and "longitude" in self._ds:
121+
latitude = self._ds["latitude"]
122+
longitude = self._ds["longitude"]
123+
if latitude.dims == coords and longitude.dims == coords:
124+
latitude = latitude.data
125+
longitude = longitude.data
126+
points["y"] = latitude
127+
points["x"] = longitude
128+
129+
if not points:
130+
key = ("grid_points", tuple(keys), tuple(coords))
131+
132+
if key in self._cache:
133+
points = self._cache[key]
134+
else:
135+
v0, v1 = data_array.coords[coords[0]], data_array.coords[coords[1]]
136+
points[keys[1]], points[keys[0]] = np.meshgrid(v1, v0)
137+
self._cache[key] = points
138+
139+
if flatten:
140+
points["x"] = points["x"].reshape(-1)
141+
points["y"] = points["y"].reshape(-1)
142+
143+
if dtype is not None:
144+
return points["y"].astype(dtype), points["x"].astype(dtype)
145+
else:
146+
return points["y"], points["x"]

earthkit/data/readers/netcdf/field.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,20 @@ def __init__(self, metadata, data_array, ds, variable):
3030
self.north, self.west, self.south, self.east = self.ds.bbox(variable)
3131

3232
def latitudes(self, dtype=None):
33-
return self.y(dtype=dtype)
33+
return self.ds._get_latlon(self.data_array, flatten=True, dtype=dtype)[0]
3434

3535
def longitudes(self, dtype=None):
36-
return self.x(dtype=dtype)
36+
return self.ds._get_latlon(self.data_array, flatten=True, dtype=dtype)[1]
3737

3838
def x(self, dtype=None):
39-
return self.ds._get_xy(self.data_array, "x", flatten=True, dtype=dtype)
39+
return self.ds._get_xy(self.data_array, flatten=True, dtype=dtype)[0]
4040

4141
def y(self, dtype=None):
42-
return self.ds._get_xy(self.data_array, "y", flatten=True, dtype=dtype)
42+
return self.ds._get_xy(self.data_array, flatten=True, dtype=dtype)[1]
4343

4444
def shape(self):
45-
coords = self.ds._get_xy_coords(self.data_array)
46-
return tuple([self.data_array.coords[v[1]].size for v in coords])
45+
_, coords = self.ds._get_xy_coords(self.data_array)
46+
return tuple([self.data_array.coords[v].size for v in coords])
4747

4848
def _unique_grid_id(self):
4949
return self.shape

tests/netcdf/test_netcdf_geography.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,78 @@ def test_netcdf_proj_string_laea():
167167
)
168168

169169

170+
def test_netcdf_to_points_laea():
171+
ds = from_source("url", earthkit_remote_test_data_file("examples", "efas.nc"))
172+
173+
assert len(ds) == 3
174+
175+
pos = [(0, 0), (0, -1), (-1, 0), (-1, -1)]
176+
177+
# we must check multiple fields
178+
for idx in range(2):
179+
v = ds[idx].to_points()
180+
assert isinstance(v, dict)
181+
182+
# lon
183+
assert isinstance(v["x"], np.ndarray)
184+
assert v["x"].shape == (950, 1000)
185+
186+
ref = np.array([2502500.0, 7497500.0, 2502500.0, 7497500.0])
187+
for i, x in enumerate(pos):
188+
assert np.isclose(v["x"][x], ref[i]), f"{i=}, {x=}"
189+
190+
# lat
191+
assert isinstance(v["y"], np.ndarray)
192+
assert v["y"].shape == (950, 1000)
193+
194+
ref = np.array([5497500.0, 5497500.0, 752500.0, 752500.0])
195+
for i, x in enumerate(pos):
196+
assert np.isclose(v["y"][x], ref[i]), f"{i=}, {x=}"
197+
198+
199+
def test_netcdf_to_latlon_laea():
200+
ds = from_source("url", earthkit_remote_test_data_file("examples", "efas.nc"))
201+
202+
assert len(ds) == 3
203+
204+
pos = [(0, 0), (0, -1), (-1, 0), (-1, -1)]
205+
206+
# we must check multiple fields
207+
for idx in range(2):
208+
v = ds[idx].to_latlon()
209+
assert isinstance(v, dict)
210+
211+
# lon
212+
assert isinstance(v["lon"], np.ndarray)
213+
assert v["lon"].shape == (950, 1000)
214+
215+
ref = np.array(
216+
[
217+
-35.034023999999995,
218+
73.93767587613708,
219+
-8.229274420493763,
220+
41.13970495087975,
221+
]
222+
)
223+
for i, x in enumerate(pos):
224+
assert np.isclose(v["lon"][x], ref[i]), f"{i=}, {x=}"
225+
226+
# lat
227+
assert isinstance(v["lat"], np.ndarray)
228+
assert v["lat"].shape == (950, 1000)
229+
230+
ref = np.array(
231+
[
232+
66.9821429989222,
233+
58.24673887576243,
234+
27.802844211251625,
235+
23.942342882929605,
236+
]
237+
)
238+
for i, x in enumerate(pos):
239+
assert np.isclose(v["lat"][x], ref[i]), f"{i=}, {x=}"
240+
241+
170242
if __name__ == "__main__":
171243
from earthkit.data.testing import main
172244

0 commit comments

Comments
 (0)