Skip to content

Commit 4699084

Browse files
Fix memory leak in Xarray computations using chunking (#694)
1 parent 718651e commit 4699084

File tree

5 files changed

+80
-44
lines changed

5 files changed

+80
-44
lines changed

docs/release_notes/version_0.14_updates.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ Version 0.14 Updates
22
/////////////////////////
33

44

5+
Version 0.14.1
6+
===============
7+
8+
Fixes
9+
+++++++++++++++++
10+
11+
- Fixed an issue where Xarray computations used excessive memory when the data was created with the Xarray engine using chunking (:pr:`694`).
12+
513
Version 0.14.0
614
===============
715

src/earthkit/data/readers/grib/virtual.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import logging
1111
from functools import cached_property
1212

13-
from earthkit.data import from_source
1413
from earthkit.data.core.fieldlist import Field
1514
from earthkit.data.core.metadata import WrappedMetadata
1615
from earthkit.data.utils.dates import date_to_grib
@@ -120,22 +119,19 @@ def _metadata(self):
120119
def _values(self, dtype=None):
121120
return self._field._values(dtype=dtype)
122121

123-
@cached_property
122+
@property
124123
def _field(self):
125124
if self.reference:
126125
return self.reference
127126
else:
128-
p = self.owner.retriever.get(self.request)
129-
return from_source("file", p, stream=True, read_all=True)[0]
127+
return self.owner.retriever.get(self.request)[0]
130128

131129

132130
class VirtualGribFieldList(GribFieldList):
133131
def __init__(self, request_mapper, retriever):
134132
self.request_mapper = request_mapper
135133
self.retriever = retriever
136134

137-
path = self.retriever.get(self.request_mapper.request_at(0))
138-
self.reference = from_source("file", path)[0]
139135
self._info_cache = {}
140136

141137
def __len__(self):
@@ -144,6 +140,10 @@ def __len__(self):
144140
def mutate(self):
145141
return self
146142

143+
@cached_property
144+
def reference(self):
145+
return self.retriever.get(self.request_mapper.request_at(0))[0]
146+
147147
def _getitem(self, n):
148148
if isinstance(n, int):
149149
if n < 0:

src/earthkit/data/sources/fdb.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ def __init__(self, fdb_kwargs):
103103
self.fdb_kwargs = fdb_kwargs
104104

105105
def get(self, request):
106-
fdb = pyfdb.FDB(**self.fdb_kwargs)
107-
s = FDBFileSource(fdb, request)
108-
return s.path
106+
from . import from_source
107+
108+
return from_source("fdb", request, stream=True, read_all=True, **self.fdb_kwargs)
109109

110110

111111
class FDBRequestMapper(RequestMapper):

src/earthkit/data/utils/xarray/builder.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#
99

1010
import logging
11+
import threading
1112
from abc import ABCMeta
1213
from abc import abstractmethod
1314

@@ -176,8 +177,7 @@ def __init__(self, tensor, dims, shape, xp, dtype, variable):
176177
if dtype is None:
177178
dtype = numpy.dtype("float64")
178179
self.dtype = xp.dtype(dtype)
179-
180-
self._var = variable
180+
self.lock = threading.Lock()
181181

182182
@property
183183
def nbytes(self):
@@ -205,39 +205,37 @@ def __getitem__(self, key: xarray.core.indexing.ExplicitIndexer):
205205
# patched in local copy for now, but could construct this ourself
206206

207207
def _raw_indexing_method(self, key: tuple):
208-
# must be threadsafe
209-
# print("_var", self._var)
210-
# print(f"dims: {self.dims} key: {key} shape: {self.shape}")
211-
# isels = dict(zip(self.dims, key))
212-
# r = self.ekds.isel(**isels)
213-
# print(f"t-coords={self.tensor.user_coords}")
214-
r = self.tensor[key]
215-
# print(r.source.ls())
216-
# print(f"r-shape: {r.user_shape}")
217-
218-
field_index = r.field_indexes(key)
219-
# print(f"field.index={field_index} coords={r.user_coords}")
220-
# result = r.to_numpy(index=field_index).squeeze()
221-
result = r.to_numpy(index=field_index, dtype=self.dtype)
222-
223-
# ensure axes are squeezed when needed
224-
singles = [i for i in list(range(len(r.user_shape))) if isinstance(key[i], int)]
225-
if singles:
226-
result = result.squeeze(axis=tuple(singles))
227-
228-
# print("result", result.shape)
229-
# result = self.ekds.isel(**isels).to_numpy()
230-
231-
# print("result", result.shape)
232-
# print(f"Loaded {self.xp.__name__} with shape: {result.shape}")
233-
234-
# Loading as numpy but then converting. This needs to be changed upstream (eccodes)
235-
# to load directly into cupy.
236-
# Maybe some incompatibilities when trying to copy from FFI to cupy directly
237-
if self.xp and self.xp != numpy:
238-
result = self.xp.asarray(result)
239-
240-
return result
208+
with self.lock:
209+
# print("_var", self._var)
210+
# print(f"dims: {self.dims} key: {key} shape: {self.shape}")
211+
# print(f"t-coords={self.tensor.user_coords}")
212+
r = self.tensor[key]
213+
# print(r.source.ls())
214+
# print(f"r-shape: {r.user_shape}")
215+
216+
field_index = r.field_indexes(key)
217+
# print(f"field.index={field_index} coords={r.user_coords}")
218+
# result = r.to_numpy(index=field_index).squeeze()
219+
result = r.to_numpy(index=field_index, dtype=self.dtype)
220+
221+
# ensure axes are squeezed when needed
222+
singles = [i for i in list(range(len(r.user_shape))) if isinstance(key[i], int)]
223+
if singles:
224+
result = result.squeeze(axis=tuple(singles))
225+
226+
# print("result", result.shape)
227+
# result = self.ekds.isel(**isels).to_numpy()
228+
229+
# print("result", result.shape)
230+
# print(f"Loaded {self.xp.__name__} with shape: {result.shape}")
231+
232+
# Loading as numpy but then converting. This needs to be changed upstream (eccodes)
233+
# to load directly into cupy.
234+
# Maybe some incompatibilities when trying to copy from FFI to cupy directly
235+
if self.xp and self.xp != numpy:
236+
result = self.xp.asarray(result)
237+
238+
return result
241239

242240

243241
class BackendDataBuilder(metaclass=ABCMeta):

tests/xr_engine/test_xr_chunks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,33 @@ def test_xr_engine_chunk_2(_kwargs):
9494
r = ds["2t"].mean("valid_time").load()
9595

9696
assert np.isclose(r.values.mean(), 275.9938876277779)
97+
98+
99+
@pytest.mark.skipif(True, reason="Needs to be fixed")
100+
@pytest.mark.cache
101+
@pytest.mark.parametrize(
102+
"_kwargs",
103+
[
104+
{},
105+
{"chunks": "auto"},
106+
{"chunks": {"valid_time": 1}},
107+
{"chunks": {"valid_time": 10}},
108+
{"chunks": {"valid_time": (100, 200, 432), "latitude": (4, 5, 4), "longitude": (13, 3, 8)}},
109+
{"chunks": -1},
110+
],
111+
)
112+
def test_xr_engine_chunk_3(_kwargs):
113+
# in-memory fieldlist
114+
ds_in = from_source(
115+
"url",
116+
earthkit_remote_test_data_file("test-data", "xr_engine", "date", "t2_1_year_hourly.grib"),
117+
stream=True,
118+
read_all=True,
119+
)
120+
121+
ds = ds_in.to_xarray(time_dim_mode="valid_time", **_kwargs)
122+
assert ds is not None
123+
124+
r = ds["2t"].mean("valid_time").load()
125+
126+
assert np.isclose(r.values.mean(), 275.9938876277779)

0 commit comments

Comments
 (0)