Skip to content

Commit 13e700c

Browse files
committed
refactor(model/training): remove custom (de)serialization for FitModelReturn
1 parent 101ae17 commit 13e700c

File tree

1 file changed

+1
-20
lines changed

1 file changed

+1
-20
lines changed

boiling_learning/model/training.py

+1-20
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from contextlib import contextmanager, nullcontext
44
from datetime import timedelta
5-
from pathlib import Path
65
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
76

87
import tensorflow as tf
@@ -16,11 +15,9 @@
1615
from boiling_learning.describe.described import Described
1716
from boiling_learning.describe.describers import describe
1817
from boiling_learning.io import json
19-
from boiling_learning.io.storage import Metadata, dataclass, deserialize, load, save, serialize
18+
from boiling_learning.io.storage import dataclass, load
2019
from boiling_learning.model.callbacks import RegisterEpoch, SaveHistory
2120
from boiling_learning.model.model import Evaluation, ModelArchitecture
22-
from boiling_learning.utils.dataclasses import fields, shallow_asdict
23-
from boiling_learning.utils.pathutils import resolve
2421
from boiling_learning.utils.timing import Timer
2522
from boiling_learning.utils.typeutils import typename
2623

@@ -91,22 +88,6 @@ class FitModelReturn:
9188
evaluation: Evaluation
9289

9390

94-
# NOTE: after v0.34.23, dataclasses are automatically serializable and de-serializable.
95-
# however, we are keeping this overload for backwards compatibility.
96-
@serialize.instance(FitModelReturn)
97-
def _serialize_fit_model_return(instance: FitModelReturn, path: Path) -> None:
98-
path = resolve(path, dir=True)
99-
for field_name, field in shallow_asdict(instance).items():
100-
save(field, path / field_name)
101-
102-
103-
@deserialize.dispatch(FitModelReturn)
104-
def _deserialize_fit_model_return(path: Path, _metadata: Metadata) -> FitModelReturn:
105-
return FitModelReturn(
106-
**{field.name: load(path / field.name) for field in fields(FitModelReturn)}
107-
)
108-
109-
11091
def get_fit_model(
11192
compiled_model: CompiledModel,
11293
datasets: Described[DatasetTriplet[tf.data.Dataset], json.JSONDataType],

0 commit comments

Comments
 (0)