Skip to content

Commit 91846de

Browse files
Split encoding from writing in grib output (#409)
* Split encoding from writing in grib output
1 parent 4f7ad4d commit 91846de

File tree

1 file changed

+76
-35
lines changed

1 file changed

+76
-35
lines changed

src/earthkit/data/readers/grib/output.py

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,14 @@
1010
import datetime
1111
import logging
1212
import re
13+
from io import IOBase
1314

1415
from earthkit.data.decorators import normalize
1516
from earthkit.data.decorators import normalize_grib_keys
1617
from earthkit.data.utils.humanize import list_to_human
1718

1819
LOG = logging.getLogger(__name__)
1920

20-
# Make sure the
21-
2221
ACCUMULATIONS = {("tp", 2): {"productDefinitionTemplateNumber": 8}}
2322

2423
_ORDER = (
@@ -59,15 +58,8 @@ def __getitem__(self, key):
5958
return self.handle.get(key, default=None)
6059

6160

62-
class GribOutput:
63-
def __init__(self, filename, split_output=False, template=None, **kwargs):
64-
self._files = {}
65-
self.filename = filename
66-
67-
if split_output:
68-
self.split_output = re.findall(r"\{(.*?)\}", self.filename)
69-
else:
70-
self.split_output = None
61+
class GribCoder:
62+
def __init__(self, template=None, **kwargs):
7163

7264
self.template = template
7365
self._bbox = {}
@@ -78,22 +70,13 @@ def __init__(self, filename, split_output=False, template=None, **kwargs):
7870
def _normalize_kwargs_names(self, **kwargs):
7971
return kwargs
8072

81-
def f(self, handle):
82-
if self.split_output:
83-
path = self.filename.format(**{k: handle.get(k) for k in self.split_output})
84-
else:
85-
path = self.filename
86-
87-
if path not in self._files:
88-
self._files[path] = open(path, "wb")
89-
return self._files[path], path
90-
91-
def write(
73+
def encode(
9274
self,
9375
values,
9476
check_nans=False,
9577
metadata={},
9678
template=None,
79+
return_bytes=False,
9780
**kwargs,
9881
):
9982
# Make a copy as we may modify it
@@ -148,20 +131,10 @@ def write(
148131
if values is not None:
149132
handle.set_values(values)
150133

151-
file, path = self.f(handle)
152-
handle.write(file)
134+
if return_bytes:
135+
return handle.get_message()
153136

154-
return handle, path
155-
156-
def close(self):
157-
for f in self._files.values():
158-
f.close()
159-
160-
def __enter__(self):
161-
return self
162-
163-
def __exit__(self, exc_type, exc_value, trace):
164-
self.close()
137+
return handle
165138

166139
def update_metadata(self, handle, metadata, compulsory):
167140
# TODO: revisit that logic
@@ -376,5 +349,73 @@ def _gg_field(self, values, metadata):
376349
return f"reduced_gg_{levtype}_{N}_grib{edition}"
377350

378351

352+
class GribOutput:
353+
def __init__(self, file, split_output=False, template=None, **kwargs):
354+
self._files = {}
355+
self.fileobj = None
356+
self.filename = None
357+
358+
if isinstance(file, IOBase):
359+
self.fileobj = file
360+
split_output = False
361+
else:
362+
self.filename = file
363+
364+
if split_output:
365+
self.split_output = re.findall(r"\{(.*?)\}", self.filename)
366+
else:
367+
self.split_output = None
368+
369+
self._coder = GribCoder(template=template, **kwargs)
370+
371+
def close(self):
372+
for f in self._files.values():
373+
f.close()
374+
375+
def __enter__(self):
376+
return self
377+
378+
def __exit__(self, exc_type, exc_value, trace):
379+
self.close()
380+
381+
def write(
382+
self,
383+
values,
384+
check_nans=False,
385+
metadata={},
386+
template=None,
387+
**kwargs,
388+
):
389+
handle = self._coder.encode(
390+
values,
391+
check_nans=check_nans,
392+
metadata=metadata,
393+
template=template,
394+
**kwargs,
395+
)
396+
397+
file, path = self.f(handle)
398+
handle.write(file)
399+
400+
return handle, path
401+
402+
def f(self, handle):
403+
if self.fileobj:
404+
return self.fileobj, None
405+
406+
if self.split_output:
407+
path = self.filename.format(**{k: handle.get(k) for k in self.split_output})
408+
else:
409+
path = self.filename
410+
411+
if path not in self._files:
412+
self._files[path] = open(path, "wb")
413+
return self._files[path], path
414+
415+
379416
def new_grib_output(*args, **kwargs):
380417
return GribOutput(*args, **kwargs)
418+
419+
420+
def new_grib_coder(*args, **kwargs):
421+
return GribCoder(*args, **kwargs)

0 commit comments

Comments
 (0)