|
1 | 1 | from fractions import Fraction
|
2 |
| -from pathlib import Path |
3 |
| -from typing import Any, Generic, Optional, TypeVar |
| 2 | +from typing import Generic, Optional, TypeVar |
4 | 3 |
|
5 | 4 | import funcy
|
6 | 5 | from typing_extensions import NamedTuple
|
7 | 6 |
|
8 |
| -from boiling_learning.io.storage import Metadata, deserialize, load, save, serialize |
9 |
| -from boiling_learning.utils.dataclasses import dataclass |
10 |
| -from boiling_learning.utils.pathutils import resolve |
| 7 | +from boiling_learning.io.storage import dataclass |
11 | 8 |
|
12 | 9 | _T = TypeVar('_T')
|
13 | 10 |
|
@@ -57,23 +54,3 @@ def __post_init__(self) -> None:
|
57 | 54 |
|
58 | 55 | if not (0 < self.train < 1 and 0 <= self.val < 1 and 0 < self.test < 1):
|
59 | 56 | raise ValueError('it is required that 0 < (*train*, *test*) < 1 and 0 <= *val* < 1')
|
60 |
| - |
61 |
| - |
62 |
| -@serialize.instance(DatasetTriplet) |
63 |
| -def _serialize_dataset_triplet(instance: DatasetTriplet[Any], path: Path) -> None: |
64 |
| - path = resolve(path, dir=True) |
65 |
| - |
66 |
| - ds_train, ds_val, ds_test = instance |
67 |
| - |
68 |
| - save(ds_train, path / 'train') |
69 |
| - save(ds_val, path / 'val') |
70 |
| - save(ds_test, path / 'test') |
71 |
| - |
72 |
| - |
73 |
| -@deserialize.dispatch(DatasetTriplet) |
74 |
| -def _deserialize_dataset_triplet(path: Path, metadata: Metadata) -> DatasetTriplet[Any]: |
75 |
| - ds_train = load(path / 'train') |
76 |
| - ds_val = load(path / 'val') |
77 |
| - ds_test = load(path / 'test') |
78 |
| - |
79 |
| - return DatasetTriplet(ds_train, ds_val, ds_test) |
0 commit comments