Group Backend Keyword Arguments #10422

Draft · wants to merge 4 commits into base: main
153 changes: 113 additions & 40 deletions xarray/backends/api.py
@@ -9,6 +9,7 @@
MutableMapping,
Sequence,
)
from dataclasses import asdict, fields
from functools import partial
from io import BytesIO
from numbers import Number
@@ -29,8 +30,12 @@
from xarray.backends.common import (
AbstractDataStore,
ArrayWriter,
BackendOptions,
CoderOptions,
XarrayBackendOptions,
_find_absolute_paths,
_normalize_path,
_reset_dataclass_to_false,
)
from xarray.backends.locks import _get_scheduler
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
@@ -382,19 +387,22 @@ def _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
coder_opts,
backend_opts,
**extra_tokens,
):
backend_kwargs = asdict(backend_opts)
chunks = backend_kwargs.get("chunks")
cache = backend_kwargs.get("cache")
if not isinstance(chunks, int | dict) and chunks not in {None, "auto"}:
raise ValueError(
f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
)

coders_kwargs = asdict(coder_opts)
extra_tokens.update(**coders_kwargs)
extra_tokens.update(**backend_kwargs)

_protect_dataset_variables_inplace(backend_ds, cache)
if chunks is None:
ds = backend_ds
@@ -403,11 +411,6 @@ def _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
engine,
chunks,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
**extra_tokens,
)

@@ -476,6 +479,16 @@ def _datatree_from_backend_datatree(
return tree


# @dataclass(frozen=True)
# class XarrayBackendOptions:
# chunks: Optional[T_Chunks] = None
# cache: Optional[bool] = None
# inline_array: Optional[bool] = False
# chunked_array_type: Optional[str] = None
# from_array_kwargs: Optional[dict[str, Any]] = None
# overwrite_encoded_chunks: Optional[bool] = False


def open_dataset(
filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
*,
@@ -500,6 +513,10 @@ def open_dataset(
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
coder_opts: Union[bool, CoderOptions, None] = None,
open_opts: Union[bool, BackendOptions, None] = None,
backend_opts: Union[bool, BackendOptions, None] = None,
store_opts: Union[bool, BackendOptions, None] = None,
**kwargs,
) -> Dataset:
"""Open and decode a dataset from a file or file-like object.
@@ -672,36 +689,73 @@ def open_dataset(

backend = plugins.get_backend(engine)

decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=backend.open_dataset_parameters,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)
# initialize CoderOptions with decoders if not given
# Deprecation Fallback
if coder_opts is False:
coder_opts = _reset_dataclass_to_false(backend.coder_opts)
elif coder_opts is True:
coder_opts = backend.coder_opts
elif coder_opts is None:
field_names = {f.name for f in fields(backend.coder_class)}
decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=field_names,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)
decoders["drop_variables"] = drop_variables
coder_opts = backend.coder_class(**decoders)

if backend_opts is None:
backend_opts = XarrayBackendOptions(
chunks=chunks,
cache=cache,
inline_array=inline_array,
chunked_array_type=chunked_array_type,
from_array_kwargs=from_array_kwargs,
overwrite_encoded_chunks=kwargs.pop("overwrite_encoded_chunks", None),
)

    # Check if store_opts has been overridden in the subclass.
# That indicates new-style behaviour.
# We can keep backwards compatibility.
_store_opts = backend.store_opts
if type(_store_opts) is BackendOptions:
coder_kwargs = asdict(coder_opts)
backend_ds = backend.open_dataset(
filename_or_obj,
**coder_kwargs,
**kwargs,
)
else:
if open_opts is None:
# check for open kwargs and create open_opts
field_names = {f.name for f in fields(backend.open_class)}
open_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
open_opts = backend.open_class(**open_kwargs)
if store_opts is None:
# check for store kwargs and create store_opts
field_names = {f.name for f in fields(backend.store_class)}
store_kwargs = {k: v for k, v in kwargs.items() if k in field_names}
store_opts = backend.store_class(**store_kwargs)
backend_ds = backend.open_dataset(
filename_or_obj,
coder_opts=coder_opts,
open_opts=open_opts,
store_opts=store_opts,
**kwargs,
)

overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
backend_ds = backend.open_dataset(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)
ds = _dataset_from_backend_dataset(
backend_ds,
filename_or_obj,
engine,
chunks,
cache,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
drop_variables=drop_variables,
**decoders,
coder_opts,
backend_opts,
**kwargs,
)
return ds
@@ -1838,6 +1892,9 @@ def to_netcdf(
multifile: bool = False,
invalid_netcdf: bool = False,
auto_complex: bool | None = None,
open_opts: Union[bool, BackendOptions, None] = None,
# backend_opts: Union[bool, BackendOptions, None] = None,
store_opts: Union[bool, BackendOptions, None] = None,
) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None:
"""This function creates an appropriate datastore for writing a dataset to
disk as a netCDF file
@@ -1879,12 +1936,10 @@ def to_netcdf(

try:
store_open = WRITEABLE_STORES[engine]
backend = plugins.get_backend(engine)
except KeyError as err:
raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}") from err

if format is not None:
format = format.upper() # type: ignore[assignment]

# handle scheduler specific logic
scheduler = _get_scheduler()
have_chunks = any(v.chunks is not None for v in dataset.variables.values())
@@ -1908,7 +1963,25 @@
if auto_complex is not None:
kwargs["auto_complex"] = auto_complex

store = store_open(target, mode, format, group, **kwargs)
if format is not None:
format = format.upper() # type: ignore[assignment]
kwargs["format"] = format
kwargs["group"] = group

    # fill open_opts according to the backend
kwargs_names = list(kwargs)
field_names = {f.name for f in fields(backend.open_class)}
open_kwargs = {k: kwargs.pop(k) for k in kwargs_names if k in field_names}
open_opts = backend.open_class(**open_kwargs)

    # fill store_opts according to the backend
field_names = {f.name for f in fields(backend.store_class)}
store_kwargs = {k: kwargs.pop(k) for k in kwargs_names if k in field_names}
store_opts = backend.store_class(**store_kwargs)

store = store_open(
target, mode=mode, open_opts=open_opts, store_opts=store_opts, **kwargs
)

if unlimited_dims is None:
unlimited_dims = dataset.encoding.get("unlimited_dims", None)
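
As a reading aid, a minimal sketch of how a caller might use the grouped keyword arguments introduced in open_dataset above; the file name is hypothetical, and the import path follows this diff (the option dataclasses live in xarray.backends.common), which may still change before merge.

import xarray as xr
from xarray.backends.common import CoderOptions, XarrayBackendOptions

# Bundle the CF-decoding switches into one frozen dataclass instead of
# passing decode_times / mask_and_scale / ... as separate keywords.
coder_opts = CoderOptions(decode_times=False, mask_and_scale=False)

# Bundle the xarray-level options (chunking, caching, ...) the same way.
backend_opts = XarrayBackendOptions(chunks={"time": 100}, cache=True)

ds = xr.open_dataset("example.nc", coder_opts=coder_opts, backend_opts=backend_opts)

Leaving coder_opts as None keeps the existing per-keyword behaviour, since open_dataset then builds a CoderOptions instance from the individual decode_* arguments, as shown in the diff above.
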
86 changes: 84 additions & 2 deletions xarray/backends/common.py
@@ -5,18 +5,30 @@
import time
import traceback
from collections.abc import Hashable, Iterable, Mapping, Sequence
from dataclasses import dataclass, fields, replace
from glob import glob
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
Optional,
TypeVar,
Union,
overload,
)

import numpy as np
import pandas as pd

from xarray.backends.locks import SerializableLock
from xarray.coding import strings, variables
from xarray.coding.times import CFDatetimeCoder, CFTimedeltaCoder
from xarray.coding.variables import SerializationWarning
from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.datatree import DataTree, Variable
from xarray.core.types import ReadBuffer
from xarray.core.types import ReadBuffer, T_Chunks
from xarray.core.utils import (
FrozenDict,
NdimSizeLenMixin,
@@ -41,6 +53,7 @@
NONE_VAR_NAME = "__values__"

T = TypeVar("T")
Buffer = Union[bytes, bytearray, memoryview]


@overload
@@ -646,6 +659,58 @@ def encode(self, variables, attributes):
return variables, attributes


def _reset_dataclass_to_false(instance):
field_names = [f.name for f in fields(instance)]
false_values = dict.fromkeys(field_names, False)
return replace(instance, **false_values)


@dataclass(frozen=True)
class BackendOptions:
pass


@dataclass(frozen=True)
class StoreWriteOptions:
group: Optional[str] = None
lock: Optional[SerializableLock] = None
autoclose: Optional[bool] = False


@dataclass(frozen=True)
class StoreWriteOpenOptions:
mode: Optional[str] = "r"
format: Optional[str] = "NETCDF4"


@dataclass(frozen=True)
class XarrayBackendOptions:
chunks: Optional[T_Chunks] = None
cache: Optional[bool] = None
inline_array: Optional[bool] = False
chunked_array_type: Optional[str] = None
from_array_kwargs: Optional[dict[str, Any]] = None
overwrite_encoded_chunks: Optional[bool] = False


@dataclass(frozen=True)
class CoderOptions:
# maybe add these two to disentangle masking from scaling?
# mask: Optional[bool] = None
# scale: Optional[bool] = None
mask_and_scale: Optional[bool | Mapping[str, bool]] = None
decode_times: Optional[
bool | CFDatetimeCoder | Mapping[str, bool | CFDatetimeCoder]
] = None
decode_timedelta: Optional[
bool | CFTimedeltaCoder | Mapping[str, bool | CFTimedeltaCoder]
] = None
use_cftime: Optional[bool | Mapping[str, bool]] = None
concat_characters: Optional[bool | Mapping[str, bool]] = None
decode_coords: Optional[Literal["coordinates", "all"] | bool] = None
drop_variables: Optional[str | Iterable[str]] = None


class BackendEntrypoint:
"""
``BackendEntrypoint`` is a class container and it is the main interface
@@ -683,6 +748,19 @@ class BackendEntrypoint:
open_dataset_parameters: ClassVar[tuple | None] = None
description: ClassVar[str] = ""
url: ClassVar[str] = ""
coder_class = CoderOptions
open_class = BackendOptions
store_class = BackendOptions

def __init__(
self,
coder_opts: Optional[BackendOptions] = None,
open_opts: Optional[BackendOptions] = None,
store_opts: Optional[BackendOptions] = None,
):
self.coder_opts = coder_opts if coder_opts is not None else self.coder_class()
self.open_opts = open_opts if open_opts is not None else self.open_class()
self.store_opts = store_opts if store_opts is not None else self.store_class()

def __repr__(self) -> str:
txt = f"<{type(self).__name__}>"
@@ -696,6 +774,10 @@ def open_dataset(
self,
filename_or_obj: str | os.PathLike[Any] | ReadBuffer | AbstractDataStore,
*,
coder_opts: Union[bool, CoderOptions, None] = None,
backend_opts: Union[bool, BackendOptions, None] = None,
open_opts: Union[bool, BackendOptions, None] = None,
store_opts: Union[bool, BackendOptions, None] = None,
drop_variables: str | Iterable[str] | None = None,
) -> Dataset:
"""
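
For completeness, a sketch of how a third-party backend could opt into the new grouped-options entrypoint defined in common.py above. MyStoreOptions, MyBackendEntrypoint, and their fields are hypothetical; only the coder_class/open_class/store_class attributes and the open_dataset keywords come from this diff.

from dataclasses import dataclass
from typing import Optional

from xarray.backends.common import BackendEntrypoint, BackendOptions, CoderOptions


@dataclass(frozen=True)
class MyStoreOptions(BackendOptions):
    # illustrative store-level options for a hypothetical backend
    group: Optional[str] = None
    autoclose: Optional[bool] = False


class MyBackendEntrypoint(BackendEntrypoint):
    # Overriding store_class away from the bare BackendOptions default is what
    # routes xarray.open_dataset onto the new grouped-kwargs code path in api.py.
    coder_class = CoderOptions
    store_class = MyStoreOptions

    def open_dataset(
        self,
        filename_or_obj,
        *,
        coder_opts=None,
        open_opts=None,
        store_opts=None,
        **kwargs,
    ):
        # coder_opts / open_opts / store_opts arrive as dataclass instances that
        # xarray.open_dataset assembled from the user's keyword arguments.
        raise NotImplementedError
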