Skip to content

Commit

Permalink
Adds FilteredMetaDataset and fixes length of UnionMetaDataset. (#195)
Browse files Browse the repository at this point in the history
* Fix and test length of UnionMetaDataset.

* Initial implementation of FilteredMetaDataset.

* Add test for equal classes and samples.

* Update changelog.
  • Loading branch information
seba-1511 authored Oct 21, 2020
1 parent 182b9ab commit 76fb944
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

* `UnionMetaDatasest` to getthe union of multiple MetaDatasets.
* `FilteredMetaDatasest` filter the classes used to sample tasks.
* `UnionMetaDatasest` to get the union of multiple MetaDatasets.
* Alias `MiniImageNetCNN` to `CNN4` and add `embedding_size` argument.
* Optional data augmentation schemes for vision benchmarks.
* `l2l.vision.models.ResNet12`
Expand Down
1 change: 1 addition & 0 deletions docs/pydocmd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ generate:
- learn2learn.data:
- learn2learn.data.MetaDataset
- learn2learn.data.UnionMetaDataset
- learn2learn.data.FilteredMetaDataset
- learn2learn.data.TaskDataset
- learn2learn.data.transforms:
- learn2learn.data.transforms.LoadData
Expand Down
2 changes: 1 addition & 1 deletion learn2learn/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
"""

from . import transforms
from .meta_dataset import MetaDataset, UnionMetaDataset
from .meta_dataset import MetaDataset, UnionMetaDataset, FilteredMetaDataset
from .task_dataset import TaskDataset, DataDescription
54 changes: 53 additions & 1 deletion learn2learn/data/meta_dataset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,56 @@ class UnionMetaDataset(MetaDataset):
ds_count += len(dataset)

def __len__(self):
return len(self.labels_to_indices)
return len(self.indices_to_labels)


class FilteredMetaDataset(MetaDataset):

"""
**Description**
Takes in a MetaDataset and filters it to only include a subset of its labels.
Note: The labels of all datasets are **not** remapped.
(i.e. the labels from the original dataset are retained)
**Arguments**
* **datasets** (Dataset) - A torch Datasets.
* **labels** (list of ints) - A list of labels to keep.
**Example**
~~~python
train = torchvision.datasets.CIFARFS(root="/tmp/mnist", mode="train")
train = l2l.data.MetaDataset(train)
filtered = FilteredMetaDataset(train, [4, 8, 2, 1, 9])
assert len(filtered.labels) == 5
~~~
"""

def __init__(self, dataset, labels):
if not isinstance(dataset, MetaDataset):
dataset = MetaDataset(dataset)
self.dataset = dataset
self.to_true_indices = []
labels_to_indices = defaultdict(list)
indices_to_labels = defaultdict(int)
idx_count = 0
for label in labels:
for true_idx in dataset.labels_to_indices[label]:
self.to_true_indices.append(true_idx)
labels_to_indices[label].append(idx_count)
indices_to_labels[idx_count] = dataset.indices_to_labels[true_idx]
idx_count += 1

self.labels_to_indices = labels_to_indices
self.indices_to_labels = indices_to_labels
self.labels = list(self.labels_to_indices.keys())

def __getitem__(self, item):
true_idx = self.to_true_indices[item]
return self.dataset[true_idx]

def __len__(self):
return len(self.indices_to_labels)
24 changes: 24 additions & 0 deletions tests/unit/data/metadataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_union_metadataset(self):
]
datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
union = l2l.data.UnionMetaDataset(datasets)
self.assertEqual(len(union), sum([len(ds) for ds in datasets]))
self.assertTrue(len(union.labels) == sum([len(ds.labels) for ds in datasets]))
self.assertTrue(len(union.indices_to_labels) == sum([len(ds.indices_to_labels) for ds in datasets]))
ref = datasets[1][23]
Expand All @@ -87,6 +88,29 @@ def test_union_metadataset(self):
# self.assertTrue(item[1] == ref[1]) # Would fail, because labels are remapped.
self.assertTrue(np.linalg.norm(np.array(item[0]) - np.array(ref[0])) <= 1e-6)

def test_filtered_metadataset(self):
for ds_class in [
l2l.vision.datasets.FC100,
l2l.vision.datasets.CIFARFS,
]:
datasets = [
ds_class('~/data', mode='train', download=True),
ds_class('~/data', mode='validation', download=True),
ds_class('~/data', mode='test', download=True),
]
datasets = [l2l.data.MetaDataset(ds) for ds in datasets]
union = l2l.data.UnionMetaDataset(datasets)
classes = datasets[1].labels
filtered = l2l.data.FilteredMetaDataset(union, classes)
self.assertEqual(len(filtered.labels), len(datasets[1].labels))
self.assertEqual(len(filtered), len(datasets[1]))
for label in filtered.labels:
self.assertTrue(label in datasets[1].labels)
self.assertEqual(
len(filtered.labels_to_indices[label]),
len(datasets[1].labels_to_indices[label])
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 76fb944

Please sign in to comment.