Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output Struct Overhaul #445

Open
wants to merge 18 commits into
base: v4-prep
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/py21cmfast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from ._cfg import config
from ._data import DATA_PATH
from ._logging import configure_logging
from .drivers.coeval import Coeval, run_coeval
from .drivers.lightcone import LightCone, exhaust_lightcone, run_lightcone
from .drivers.coeval import Coeval, generate_coeval, run_coeval
from .drivers.lightcone import LightCone, generate_lightcone, run_lightcone
from .drivers.single_field import (
brightness_temperature,
compute_halo_grid,
Expand Down
11 changes: 5 additions & 6 deletions src/py21cmfast/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from . import _cfg, global_params, plotting
from .drivers.coeval import run_coeval
from .drivers.lightcone import exhaust_lightcone, run_lightcone
from .drivers.lightcone import run_lightcone
from .drivers.single_field import (
compute_initial_conditions,
compute_ionization_field,
Expand Down Expand Up @@ -492,7 +492,7 @@ def ionize(ctx, redshift, prev_z, config, regen, direc, seed):
help="Whether to force regeneration of init/perturb files if they already exist.",
)
@click.option(
"--direc",
"--cache-dir",
type=click.Path(exists=True, dir_okay=True),
default=None,
help="cache directory",
Expand All @@ -504,7 +504,7 @@ def ionize(ctx, redshift, prev_z, config, regen, direc, seed):
help="specify a random seed for the initial conditions",
)
@click.pass_context
def coeval(ctx, redshift, config, out, regen, direc, seed):
def coeval(ctx, redshift, config, out, regen, cache_dir, seed):
"""Efficiently generate coeval cubes at a given redshift.

Parameters
Expand Down Expand Up @@ -554,8 +554,7 @@ def coeval(ctx, redshift, config, out, regen, direc, seed):
inputs=inputs,
regenerate=regen,
write=True,
direc=direc,
random_seed=seed,
cache=OutputCache(cache_dir) if cache_dir else None,
)

if out:
Expand Down Expand Up @@ -678,7 +677,7 @@ def lightcone(ctx, redshift, config, out, regen, direc, max_z, seed, lq):
quantities=lq,
)

lc = exhaust_lightcone(
lc = run_lightcone(
lightconer=lcn,
inputs=inputs,
regenerate=regen,
Expand Down
54 changes: 44 additions & 10 deletions src/py21cmfast/drivers/_param_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,23 @@ def __init__(self, _func: callable):
f"{_func.__name__} must return an instance of OutputStruct (and be annotated as such)."
)

@staticmethod
def _get_all_output_struct_inputs(
kwargs, recurse: bool = False
) -> dict[str, OutputStruct]:
"""Return all the arguments that are OutputStructs.

If recurse is True, also add all OutputStructs that are part of iterables.
"""
d = {k: v for k, v in kwargs.items() if isinstance(v, OutputStruct)}

if recurse:
for k, v in kwargs.items():
if hasattr(v, "__len__") and isinstance(v[0], OutputStruct):
d |= {f"{k}_{i}": vv for i, vv in enumerate(v)}

return d

@staticmethod
def _get_inputs(kwargs: dict[str, Any]) -> InputParameters:
"""Return the most detailed input parameters available.
Expand All @@ -180,26 +197,29 @@ def _get_inputs(kwargs: dict[str, Any]) -> InputParameters:
situation that different dependent OutputStruct's have different inputs. Even
though all must be compatible with each other, more basic OutputStructs (like
InitialConditions) might not have the same zgrid as the PerturbedField (for
example) and this fine. So, here we return the inputs of the "most advanced"
example) and this is fine. So, here we return the inputs of the "most advanced"
OutputStruct that is given.
"""
inputs = kwargs.get("inputs")
if inputs is not None:
return inputs

outputs = single_field_func._get_all_output_struct_inputs(kwargs)
outputs = _OutputStructComputationInspect._get_all_output_struct_inputs(
kwargs, recurse=True
)

minreq = _HashType(0)
for output in outputs.values():
if output._compat_hash.value >= minreq.value:
inputs = output.inputs
minreq = output._compat_hash

return inputs
if inputs is None:
raise ValueError(
"No parameter 'inputs' given, and no dependent OutputStruct found!"
)

@staticmethod
def _get_all_output_struct_inputs(kwargs):
return {k: v for k, v in kwargs.items() if isinstance(v, OutputStruct)}
return inputs

@staticmethod
def check_consistency(kwargs: dict[str, Any], outputs: dict[str, OutputStruct]):
Expand Down Expand Up @@ -384,12 +404,13 @@ def __call__(self, **kwargs) -> OutputStruct:
"""Call the single field function."""
inputs = self._get_inputs(kwargs)
outputs = self._get_all_output_struct_inputs(kwargs)
outputs_rec = self._get_all_output_struct_inputs(kwargs, recurse=True)
outputsz = {k: v for k, v in outputs.items() if isinstance(v, OutputStructZ)}

# Get current redshift (could be None)
current_redshift = self._get_current_redshift(outputsz, kwargs)

self.check_consistency(kwargs, outputs)
self.check_consistency(kwargs, outputs_rec)
self.check_output_struct_types(outputs)
# The following checks both current and previous redshifts, if applicable
self.ensure_redshift_consistency(current_redshift, outputsz)
Expand All @@ -416,14 +437,27 @@ def __call__(self, **kwargs) -> OutputStruct:
return out


class high_level_func: # noqa: N801
class high_level_func(_OutputStructComputationInspect): # noqa: N801
"""A decorator for high-level functions like ``run_coeval``."""

# def __init__(self, _func: callable):
# self._func = _func
def __init__(self, _func: callable):
self._func = _func
self._signature = inspect.signature(_func)
self._kls = self._signature.return_annotation

def __call__(self, **kwargs):
"""Call the function."""
outputs = _OutputStructComputationInspect._get_all_output_struct_inputs(kwargs)
_OutputStructComputationInspect.check_consistency(kwargs, outputs)
outputs = self._get_all_output_struct_inputs(kwargs, recurse=True)
inputs = self._get_inputs(kwargs)
if "inputs" in self._signature.parameters:
# Here we set the inputs (if accepted by the function signature)
# to the most advanced ones. This is the explicitly-passed inputs if
# they exist, but otherwise the inputs derived from the dependency
# that is the most advanced in the computation.
kwargs["inputs"] = inputs

self.check_consistency(kwargs, outputs)

yield from self._func(**kwargs)
Loading
Loading