Skip to content

Commit 101ae17

Browse files
committed
feat: make dataclasses automatically de-serializable
1 parent 2dd9368 commit 101ae17

File tree

12 files changed

+37
-43
lines changed

12 files changed

+37
-43
lines changed

boiling_learning/automl/tuning.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from boiling_learning.datasets.datasets import DatasetTriplet
77
from boiling_learning.describe.described import Described
88
from boiling_learning.io import json
9-
from boiling_learning.io.storage import register_deserializer_for_dataclass
9+
from boiling_learning.io.storage import dataclass
1010
from boiling_learning.model.model import Evaluation, ModelArchitecture
11-
from boiling_learning.utils.dataclasses import dataclass
1211

1312

1413
@dataclass(frozen=True)
@@ -17,7 +16,6 @@ class TuneModelParams:
1716
batch_size: int
1817

1918

20-
@register_deserializer_for_dataclass
2119
@dataclass(frozen=True)
2220
class TuneModelReturn:
2321
model: ModelArchitecture

boiling_learning/daq/devices.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from nidaqmx.task import Task
22

3-
from boiling_learning.utils.dataclasses import dataclass
3+
from boiling_learning.io.storage import dataclass
44

55

66
@dataclass

boiling_learning/data/boiling_curve.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pint import Quantity
44
from typing_extensions import Literal
55

6-
from boiling_learning.utils.dataclasses import dataclass
6+
from boiling_learning.io.storage import dataclass
77
from boiling_learning.utils.units import unit_registry as ureg
88

99
Q_ = ureg.Quantity

boiling_learning/datasets/datasets.py

+2-25
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
from fractions import Fraction
2-
from pathlib import Path
3-
from typing import Any, Generic, Optional, TypeVar
2+
from typing import Generic, Optional, TypeVar
43

54
import funcy
65
from typing_extensions import NamedTuple
76

8-
from boiling_learning.io.storage import Metadata, deserialize, load, save, serialize
9-
from boiling_learning.utils.dataclasses import dataclass
10-
from boiling_learning.utils.pathutils import resolve
7+
from boiling_learning.io.storage import dataclass
118

129
_T = TypeVar('_T')
1310

@@ -57,23 +54,3 @@ def __post_init__(self) -> None:
5754

5855
if not (0 < self.train < 1 and 0 <= self.val < 1 and 0 < self.test < 1):
5956
raise ValueError('it is required that 0 < (*train*, *test*) < 1 and 0 <= *val* < 1')
60-
61-
62-
@serialize.instance(DatasetTriplet)
63-
def _serialize_dataset_triplet(instance: DatasetTriplet[Any], path: Path) -> None:
64-
path = resolve(path, dir=True)
65-
66-
ds_train, ds_val, ds_test = instance
67-
68-
save(ds_train, path / 'train')
69-
save(ds_val, path / 'val')
70-
save(ds_test, path / 'test')
71-
72-
73-
@deserialize.dispatch(DatasetTriplet)
74-
def _deserialize_dataset_triplet(path: Path, metadata: Metadata) -> DatasetTriplet[Any]:
75-
ds_train = load(path / 'train')
76-
ds_val = load(path / 'val')
77-
ds_test = load(path / 'test')
78-
79-
return DatasetTriplet(ds_train, ds_val, ds_test)

boiling_learning/io/storage.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

33
import itertools
4+
from dataclasses import dataclass as _dataclass
45
from datetime import timedelta
6+
from functools import wraps
57
from pathlib import Path
6-
from typing import Any, Type, TypeVar
8+
from typing import Any, Callable, Type, TypeVar
79

810
from classes import AssociatedType, Supports, typeclass
911
from typing_extensions import final
@@ -15,6 +17,8 @@
1517

1618
Metadata = json.JSONDataType
1719
_DataClassType = TypeVar('_DataClassType', bound=Type[DataClass])
20+
_Any = TypeVar('_Any')
21+
_AnyCallable = TypeVar('_AnyCallable', bound=Callable[..., Any])
1822

1923

2024
@final
@@ -251,3 +255,16 @@ def _deserialize(path: Path, metadata: Metadata) -> DataClass:
251255
return dataclass_type(**fields)
252256

253257
return dataclass_type
258+
259+
260+
def _identity_compose(
261+
identity_transform: Callable[[_Any], _Any], function: _AnyCallable
262+
) -> _AnyCallable:
263+
@wraps(function)
264+
def _wrapped(*args, **kwargs) -> Any:
265+
return identity_transform(function(*args, **kwargs))
266+
267+
return _wrapped
268+
269+
270+
dataclass = _identity_compose(register_deserializer_for_dataclass, _dataclass)

boiling_learning/model/model.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def _deserialize_model(path: Path, _metadata: Metadata) -> ModelArchitecture:
100100
def model_memory_usage_in_bytes(
101101
architecture: ModelArchitecture, *, batch_size: int
102102
) -> Quantity[int]:
103-
"""
104-
Return the estimated memory usage of a given Keras model in bytes.
103+
"""Return the estimated memory usage of a given Keras model in bytes.
104+
105105
This includes the model weights and layers, but excludes the dataset.
106106
107107
The model shapes are multiplied by the batch size, but the weights are not.
@@ -115,7 +115,6 @@ def model_memory_usage_in_bytes(
115115
pass `1` as the argument here.
116116
Returns:
117117
An estimate of the Keras model's memory usage in bytes.
118-
119118
"""
120119
model = architecture.model
121120

boiling_learning/model/training.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
from boiling_learning.describe.described import Described
1717
from boiling_learning.describe.describers import describe
1818
from boiling_learning.io import json
19-
from boiling_learning.io.storage import Metadata, deserialize, load, save, serialize
19+
from boiling_learning.io.storage import Metadata, dataclass, deserialize, load, save, serialize
2020
from boiling_learning.model.callbacks import RegisterEpoch, SaveHistory
2121
from boiling_learning.model.model import Evaluation, ModelArchitecture
22-
from boiling_learning.utils.dataclasses import dataclass, fields, shallow_asdict
22+
from boiling_learning.utils.dataclasses import fields, shallow_asdict
2323
from boiling_learning.utils.pathutils import resolve
2424
from boiling_learning.utils.timing import Timer
2525
from boiling_learning.utils.typeutils import typename
@@ -91,6 +91,8 @@ class FitModelReturn:
9191
evaluation: Evaluation
9292

9393

94+
# NOTE: after v0.34.23, dataclasses are automatically serializable and de-serializable.
95+
# However, this overload is kept for backwards compatibility.
9496
@serialize.instance(FitModelReturn)
9597
def _serialize_fit_model_return(instance: FitModelReturn, path: Path) -> None:
9698
path = resolve(path, dir=True)

boiling_learning/preprocessing/experiment_video.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from boiling_learning.datasets.sliceable import SliceableDataset
1111
from boiling_learning.describe.describers import describe
1212
from boiling_learning.io import json
13+
from boiling_learning.io.storage import dataclass
1314
from boiling_learning.preprocessing.video import Video, VideoFrame, convert_video
14-
from boiling_learning.utils.dataclasses import dataclass, field
15+
from boiling_learning.utils.dataclasses import field
1516
from boiling_learning.utils.pathutils import PathLike, resolve
1617

1718

boiling_learning/preprocessing/image_datasets.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
from boiling_learning.describe.describers import describe
88
from boiling_learning.io import json
9+
from boiling_learning.io.storage import dataclass
910
from boiling_learning.preprocessing.experiment_video import ExperimentVideo
1011
from boiling_learning.utils.collections import KeyedSet
11-
from boiling_learning.utils.dataclasses import dataclass, dataclass_from_mapping
12+
from boiling_learning.utils.dataclasses import dataclass_from_mapping
1213
from boiling_learning.utils.pathutils import PathLike, resolve
1314

1415

boiling_learning/utils/dataclasses.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import typing
2-
from dataclasses import asdict, dataclass, field, fields, is_dataclass
2+
from dataclasses import asdict, field, fields, is_dataclass
33
from typing import Any, Callable, Dict, Mapping, Optional, Type, TypeVar, Union
44

55
from typing_extensions import TypeGuard
66

77
__all__ = (
88
'asdict',
9-
'dataclass',
109
'field',
1110
'fields',
1211
'is_dataclass',

main.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
)
4141
from boiling_learning.describe.described import Described
4242
from boiling_learning.io import json
43-
from boiling_learning.io.storage import load, save
43+
from boiling_learning.io.storage import dataclass, load, save
4444
from boiling_learning.management.allocators import default_table_allocator
4545
from boiling_learning.management.cacher import CachedFunction, Cacher
4646
from boiling_learning.model.callbacks import (
@@ -74,7 +74,6 @@
7474
set_condensation_datasets_data,
7575
)
7676
from boiling_learning.scripts.utils.initialization import check_all_paths_exist
77-
from boiling_learning.utils.dataclasses import dataclass
7877
from boiling_learning.utils.functional import P
7978
from boiling_learning.utils.lazy import Lazy, LazyCallable
8079
from boiling_learning.utils.pathutils import resolve

tests/test_utils/test_dataclasses.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from dataclasses import dataclass
2+
13
from boiling_learning.utils.dataclasses import (
2-
dataclass,
34
is_dataclass,
45
is_dataclass_class,
56
is_dataclass_instance,

0 commit comments

Comments
 (0)