From 63ff92e61b4e65760c02b20b918f6bea18b4ed8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9b=20Arnold?= Date: Tue, 7 Jul 2020 23:02:54 -0400 Subject: [PATCH] Add differentiable optimization module. (Meta-Descent, KFO, Meta-Curvature) (#151) * Ported hypergrad example. * Add meta-curvature example with GBML wrapper. * GBML support for nograd, unused, first_order and tests. * Add ANIL+KFO low-level example. * Add misc nn layers. * Update maml_update. * Change download path for mini-imagenet tests. * Add docs for differentiable sgd. * Update docs, incl. for MetaWorld. * KroneckerTranform docs. * Docs for meta-curvature. * Add docs for l2l.nn.misc. * Add docs for kroneckers. * Fix lint, add more docs. * Add docs for GBML. * Completes GBML docs. * Rename meta_update -> update_module, and write docs. * Fix lint, add docs for ParameterUpdate. * Add docs for LearnableOptimizer * Update changelog * Update to readme, part 1 * Update README, part 2. * Fix readme links * Version bump. --- CHANGELOG.md | 15 +- README.md | 173 ++++++++----- docs/pydocmd.yml | 34 ++- examples/optimization/README.md | 22 ++ examples/optimization/hypergrad_mnist.py | 135 ++++++++++ examples/vision/anilkfo_cifarfs.py | 232 ++++++++++++++++++ examples/vision/maml_miniimagenet.py | 14 +- examples/vision/metacurvature_fc100.py | 190 ++++++++++++++ learn2learn/__init__.py | 4 +- learn2learn/_version.py | 2 +- learn2learn/algorithms/__init__.py | 5 + learn2learn/algorithms/gbml.py | 202 +++++++++++++++ learn2learn/algorithms/maml.py | 31 +-- learn2learn/data/__init__.py | 4 + learn2learn/gym/__init__.py | 4 + learn2learn/gym/envs/metaworld/metaworld.py | 11 +- learn2learn/nn/__init__.py | 8 + learn2learn/nn/kroneckers.py | 196 +++++++++++++++ learn2learn/nn/misc.py | 86 +++++++ learn2learn/optim/__init__.py | 10 + learn2learn/optim/learnable_optimizer.py | 107 ++++++++ learn2learn/optim/parameter_update.py | 138 +++++++++++ learn2learn/optim/transforms/__init__.py | 13 + .../optim/transforms/kronecker_transform.py | 75 ++++++ .../transforms/metacurvature_transform.py | 82 +++++++ .../optim/transforms/module_transform.py | 66 +++++ .../optim/transforms/transform_dictionary.py | 25 ++ learn2learn/optim/update_rules/__init__.py | 3 + .../optim/update_rules/differentiable_sgd.py | 57 +++++ learn2learn/utils.py | 79 +++++- learn2learn/vision/__init__.py | 4 + .../maml_miniimagenet_test_notravis.py | 6 +- .../protonets_miniimagenet_test_notravis.py | 2 +- tests/unit/algorithms/gbml_test.py | 151 ++++++++++++ tests/unit/data/util_datasets.py | 2 + tests/unit/vision/benchmarks_test.py | 2 +- 36 files changed, 2079 insertions(+), 111 deletions(-) create mode 100644 examples/optimization/README.md create mode 100644 examples/optimization/hypergrad_mnist.py create mode 100644 examples/vision/anilkfo_cifarfs.py create mode 100644 examples/vision/metacurvature_fc100.py create mode 100644 learn2learn/algorithms/gbml.py create mode 100644 learn2learn/nn/__init__.py create mode 100644 learn2learn/nn/kroneckers.py create mode 100644 learn2learn/nn/misc.py create mode 100644 learn2learn/optim/__init__.py create mode 100644 learn2learn/optim/learnable_optimizer.py create mode 100644 learn2learn/optim/parameter_update.py create mode 100644 learn2learn/optim/transforms/__init__.py create mode 100644 learn2learn/optim/transforms/kronecker_transform.py create mode 100644 learn2learn/optim/transforms/metacurvature_transform.py create mode 100644 learn2learn/optim/transforms/module_transform.py create mode 100644 
learn2learn/optim/transforms/transform_dictionary.py create mode 100644 learn2learn/optim/update_rules/__init__.py create mode 100644 learn2learn/optim/update_rules/differentiable_sgd.py create mode 100644 tests/unit/algorithms/gbml_test.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f163ce8..e313f0b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +### Changed + +### Fixed + + +## v0.1.2 + +### Added + * New example: [Meta-World](https://github.com/rlworkgroup/metaworld) example with MAML-TRPO with it's own env wrapper. (@[Kostis-S-Z](https://github.com/Kostis-S-Z)) -* Add l2l.vision.benchmarks interface. +* `l2l.vision.benchmarks` interface. +* Differentiable optimization utilities in `l2l.optim`. (including `l2l.optim.LearnableOptimizer` for meta-descent) +* General gradient-based meta-learning wrapper in `l2l.algorithms.GBML`. +* Various `nn.Modules` in `l2l.nn`. +* `l2l.update_module` as a more general alternative to `l2l.algorithms.maml_update`. ### Changed diff --git a/README.md b/README.md index 70f81e8e..d1f6a618 100644 --- a/README.md +++ b/README.md @@ -4,27 +4,32 @@ [![Build Status](https://travis-ci.com/learnables/learn2learn.svg?branch=master)](https://travis-ci.com/learnables/learn2learn) -learn2learn is a PyTorch library for meta-learning implementations. +learn2learn is a software library for meta-learning research. -The goal of meta-learning is to enable agents to *learn how to learn*. -That is, we would like our agents to become better learners as they solve more and more tasks. -For example, the animation below shows an agent that learns to run after a only one parameter update. +learn2learn builds on top of PyTorch to accelerate two aspects of the meta-learning research cycle: -

+* *fast prototyping*, essential in letting researchers quickly try new ideas, and +* *correct reproducibility*, ensuring that these ideas are evaluated fairly. -**Features** +learn2learn provides low-level utilities and unified interface to create new algorithms and domains, together with high-quality implementations of existing algorithms and standardized benchmarks. +It retains compatibility with [torchvision](https://pytorch.org/vision/), [torchaudio](https://pytorch.org/audio/), [torchtext](https://pytorch.org/text/), [cherry](http://cherry-rl.net/), and any other PyTorch-based library you might be using. -learn2learn provides high- and low-level utilities for meta-learning. -The high-level utilities allow arbitrary users to take advantage of exisiting meta-learning algorithms. -The low-level utilities enable researchers to develop new and better meta-learning algorithms. +**Overview** -Some features of learn2learn include: +* [`learn2learn.data`](http://learn2learn.net/docs/learn2learn.data/): `TaskDataset` and transforms to create few-shot tasks from any PyTorch dataset. +* [`learn2learn.vision`](http://learn2learn.net/docs/learn2learn.vision/): Models, datasets, and benchmarks for computer vision and few-shot learning. +* [`learn2learn.gym`](http://learn2learn.net/docs/learn2learn.gym/): Environment and utilities for meta-reinforcement learning. +* [`learn2learn.algorithms`](http://learn2learn.net/docs/learn2learn.algorithms/): High-level wrappers for existing meta-learning algorithms. +* [`learn2learn.optim`](http://learn2learn.net/docs/learn2learn.optim/): Utilities and algorithms for differentiable optimization and meta-descent. -* Modular API: implement your own training loops with our low-level utilities. -* Provides various meta-learning algorithms (e.g. MAML, FOMAML, MetaSGD, ProtoNets, DiCE) -* Task generator with unified API, compatible with torchvision, torchtext, torchaudio, and cherry. -* Provides standardized meta-learning tasks for vision (Omniglot, mini-ImageNet), reinforcement learning (Particles, Mujoco), and even text (news classification). -* 100% compatible with PyTorch -- use your own modules, datasets, or libraries! +**Resources** + +* Website: [http://learn2learn.net/](http://learn2learn.net/) +* Documentation: [http://learn2learn.net/docs/](http://learn2learn.net/docs/) +* Tutorials: [http://learn2learn.net/tutorials/getting_started/](http://learn2learn.net/tutorials/getting_started/) +* Examples: [https://github.com/learnables/learn2learn/tree/master/examples](https://github.com/learnables/learn2learn/tree/master/examples) +* GitHub: [https://github.com/learnables/learn2learn/](https://github.com/learnables/learn2learn/) +* Slack: [http://slack.learn2learn.net/](http://slack.learn2learn.net/) ## Installation @@ -32,53 +37,109 @@ Some features of learn2learn include: pip install learn2learn ~~~ -## API Demo +## Snippets & Examples + +The following snippets provide a sneak peek at the functionalities of learn2learn. + +### High-level Wrappers -The following is an example of using the high-level MAML implementation on MNIST. -For more algorithms and lower-level utilities, please refer to the [documentation](http://learn2learn.net/docs/learn2learn/) or the [examples](https://github.com/learnables/learn2learn/tree/master/examples). +**Few-Shot Learning with MAML** +For more algorithms (ProtoNets, ANIL, Meta-SGD, Reptile, Meta-Curvature, KFO) refer to the [examples](https://github.com/learnables/learn2learn/tree/master/examples/vision) folder. 
+Most of them can be implemented with with the `GBML` wrapper. ([documentation](http://learn2learn.net/docs/learn2learn.algorithms/#gbml)). ~~~python -import learn2learn as l2l - -mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True) - -mnist = l2l.data.MetaDataset(mnist) -train_tasks = l2l.data.TaskDataset(mnist, - task_transforms=[ - NWays(mnist, n=3), - KShots(mnist, k=1), - LoadData(mnist), - ], - num_tasks=10) -model = Net() -maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False) -opt = optim.Adam(maml.parameters(), lr=4e-3) - -for iteration in range(num_iterations): - learner = maml.clone() # Creates a clone of model - for task in train_tasks: - # Split task in adaptation_task and evalutation_task - # Fast adapt - for step in range(adaptation_steps): - error = compute_loss(adaptation_task) - learner.adapt(error) - - # Compute evaluation loss - evaluation_error = compute_loss(evaluation_task) - - # Meta-update the model parameters - opt.zero_grad() - evaluation_error.backward() - opt.step() +maml = l2l.algorithms.MAML(model, lr=0.1) +opt = torch.optim.SGD(maml.parameters(), lr=0.001) +for iteration in range(10): + opt.zero_grad() + task_model = maml.clone() # torch.clone() for nn.Modules + adaptation_loss = compute_loss(task_model) + task_model.adapt(adaptation_loss) # computes gradient, update task_model in-place + evaluation_loss = compute_loss(task_model) + evaluation_loss.backward() # gradients w.r.t. maml.parameters() + opt.step() ~~~ -## Changelog +**Meta-Descent with Hypergradient** -A human-readable changelog is available in the [CHANGELOG.md](CHANGELOG.md) file. +Learn any kind of optimization algorithm with the `LearnableOptimizer`. ([example](https://github.com/learnables/learn2learn/tree/master/examples/optimization) and [documentation](http://learn2learn.net/docs/learn2learn.optim/#learnableoptimizer)) +~~~python +linear = nn.Linear(784, 10) +transform = l2l.optim.ModuleTransform(l2l.nn.Scale) +metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01) # metaopt has .step() +opt = torch.optim.SGD(metaopt.parameters(), lr=0.001) # metaopt also has .parameters() + +metaopt.zero_grad() +opt.zero_grad() +error = loss(linear(X), y) +error.backward() +opt.step() # update metaopt +metaopt.step() # update linear +~~~ + +### Learning Domains + +**Custom Few-Shot Dataset** + +Many standardized datasets (Omniglot, mini-/tiered-ImageNet, FC100, CIFAR-FS) are readily available in `learn2learn.vision.datasets`. +([documentation](http://learn2learn.net/docs/learn2learn.vision/#learn2learnvisiondatasets)) +~~~python +dataset = l2l.data.MetaDataset(MyDataset()) # any PyTorch dataset +transforms = [ # Easy to define your own transform + l2l.data.transforms.NWays(dataset, n=5), + l2l.data.transforms.KShots(dataset, k=1), + l2l.data.transforms.LoadData(dataset), +] +taskset = TaskDataset(dataset, transforms, num_tasks=20000) +for task in taskset: + X, y = task + # Meta-train on the task +~~~ -## Documentation +**Environments and Utilities for Meta-RL** -Documentation and tutorials are available on learn2learn’s website: [http://learn2learn.net](http://learn2learn.net). +Parallelize your own meta-environments with `AsyncVectorEnv`, or use the standardized ones. 
+([documentation](http://learn2learn.net/docs/learn2learn.gym/#metaenv)) +~~~python +def make_env(): + env = l2l.gym.HalfCheetahForwardBackwardEnv() + env = cherry.envs.ActionSpaceScaler(env) + return env + +env = l2l.gym.AsyncVectorEnv([make_env for _ in range(16)]) # uses 16 threads +for task_config in env.sample_tasks(20): + env.set_task(task) # all threads receive the same task + state = env.reset() # use standard Gym API + action = my_policy(env) + env.step(action) +~~~ + +### Low-Level Utilities + +**Differentiable Optimization** + +Learn and differentiate through updates of PyTorch Modules. +([documentation](http://learn2learn.net/docs/learn2learn.optim/#parameterupdate)) +~~~python + +model = MyModel() +transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear) +learned_update = l2l.optim.ParameterUpdate( # learnable update function + model.parameters(), transform) +clone = l2l.clone_module(model) # torch.clone() for nn.Modules +error = loss(clone(X), y) +updates = learned_update( # similar API as torch.autograd.grad + error, + clone.parameters(), + create_graph=True, +) +l2l.update_module(clone, updates=updates) +loss(clone(X), y).backward() # Gradients w.r.t model.parameters() and learned_update.parameters() +~~~ + +## Changelog + +A human-readable changelog is available in the [CHANGELOG.md](CHANGELOG.md) file. ## Citation @@ -101,5 +162,5 @@ You can also use the following Bibtex entry. ### Acknowledgements & Friends 1. The RL environments are adapted from Tristan Deleu's [implementations](https://github.com/tristandeleu/pytorch-maml-rl) and from the ProMP [repository](https://github.com/jonasrothfuss/ProMP/). Both shared with permission, under the MIT License. -2. [TorchMeta](https://github.com/tristandeleu/pytorch-meta) is similar library, with a focus on supervised meta-learning. If learn2learn were missing a particular functionality, we would go check if TorchMeta has it. But we would also open an issue ;) -3. [higher](https://github.com/facebookresearch/higher) is a PyTorch library that also enables differentiating through optimization inner-loops. Their approach is different from learn2learn in that they monkey-patch nn.Module to be stateless. For more information, refer to [their ArXiv paper](https://arxiv.org/abs/1910.01727). +2. [TorchMeta](https://github.com/tristandeleu/pytorch-meta) is similar library, with a focus on datasets for supervised meta-learning. +3. [higher](https://github.com/facebookresearch/higher) is a PyTorch library that enables differentiating through optimization inner-loops. While they monkey-patch `nn.Module` to be stateless, learn2learn retains the stateful PyTorch look-and-feel. For more information, refer to [their ArXiv paper](https://arxiv.org/abs/1910.01727). diff --git a/docs/pydocmd.yml b/docs/pydocmd.yml index 00336a75..21d5e8a4 100644 --- a/docs/pydocmd.yml +++ b/docs/pydocmd.yml @@ -6,9 +6,10 @@ site_name: "learn2learn" # documented. Higher indentation leads to smaller header size. 
generate: - docs/learn2learn.md: - - learn2learn.utils: + - learn2learn: - learn2learn.clone_module - learn2learn.detach_module + - learn2learn.update_module - learn2learn.magic_box - docs/learn2learn.data.md: - learn2learn.data: @@ -25,9 +26,8 @@ generate: - docs/learn2learn.algorithms.md: - learn2learn.algorithms: - learn2learn.algorithms.MAML++ - - learn2learn.algorithms.maml_update - learn2learn.algorithms.MetaSGD++ - - learn2learn.algorithms.meta_sgd_update + - learn2learn.algorithms.GBML++ - docs/learn2learn.gym.md: - learn2learn.gym++: - learn2learn.gym.MetaEnv @@ -40,6 +40,27 @@ generate: - learn2learn.gym.envs.mujoco.HumanoidDirectionEnv - learn2learn.gym.envs.particles: - learn2learn.gym.envs.particles.Particles2DEnv + - learn2learn.gym.envs.metaworld: + - learn2learn.gym.envs.metaworld.MetaWorldML1++ + - learn2learn.gym.envs.metaworld.MetaWorldML10++ + - learn2learn.gym.envs.metaworld.MetaWorldML45++ + - docs/learn2learn.optim.md: + - learn2learn.optim++: + - learn2learn.optim.LearnableOptimizer++ + - learn2learn.optim.ParameterUpdate++ + - learn2learn.optim.DifferentiableSGD++ + - learn2learn.optim.transforms: + - learn2learn.optim.transforms.ModuleTransform++ + - learn2learn.optim.transforms.KroneckerTransform++ + - learn2learn.optim.transforms.MetaCurvatureTransform++ + - docs/learn2learn.nn.md: + - learn2learn.nn++: + - learn2learn.nn.Lambda + - learn2learn.nn.Flatten + - learn2learn.nn.Scale + - learn2learn.nn.KroneckerLinear + - learn2learn.nn.KroneckerRNN + - learn2learn.nn.KroneckerLSTM - docs/learn2learn.vision.md: - learn2learn.vision++: - learn2learn.vision.models: @@ -73,13 +94,16 @@ pages: - Feature Reuse with ANIL: tutorials/anil_tutorial/ANIL_tutorial.md - Documentation: - learn2learn: docs/learn2learn.md - - learn2learn.algorithms: docs/learn2learn.algorithms.md - learn2learn.data: docs/learn2learn.data.md - - learn2learn.gym: docs/learn2learn.gym.md + - learn2learn.algorithms: docs/learn2learn.algorithms.md + - learn2learn.optim: docs/learn2learn.optim.md + - learn2learn.nn: docs/learn2learn.nn.md - learn2learn.vision: docs/learn2learn.vision.md + - learn2learn.gym: docs/learn2learn.gym.md - Examples: - Computer Vision: examples.vision.md << ../examples/vision/README.md - Reinforcement Learning: examples.rl.md << ../examples/rl/README.md + - Optimization: examples.optim.md << ../examples/optimization/README.md - Changelog: changelog.md << ../CHANGELOG.md - GitHub: https://github.com/learnables/learn2learn/ diff --git a/examples/optimization/README.md b/examples/optimization/README.md new file mode 100644 index 00000000..d362c695 --- /dev/null +++ b/examples/optimization/README.md @@ -0,0 +1,22 @@ +# Meta-Optimization + +This directory contains examples of using learn2learn for meta-optimization or meta-descent. + +# Hypergradient + +The script `hypergrad_mnist.py` demonstrates how to implement a slightly modified version of "[Online Learning Rate Adaptation with Hypergradient Descent](https://arxiv.org/abs/1703.04782)". +The implementation departs from the algorithm presented in the paper in two ways. + +1. We forgo the analytical formulation of the learning rate's gradient to demonstrate the capability of the `LearnableOptimizer` class. +2. We adapt per-parameter learning rates instead of updating a single learning rate shared by all parameters. + +**Usage** + +!!! warning + The parameters for this script were not carefully tuned. 
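+
+Condensed to its essence (a sketch, assuming an existing `model` module; the actual transform in `hypergrad_mnist.py` is called `HypergradTransform`, and `PerParamLR` below is only an illustrative name), the transform wraps every gradient entry with its own learnable rate, and `LearnableOptimizer` applies the result as the update:
+
+~~~python
+class PerParamLR(torch.nn.Module):  # hypothetical name; mirrors the script's HypergradTransform
+    def __init__(self, param, lr=0.01):
+        super().__init__()
+        # one learnable learning rate per parameter entry
+        self.lr = torch.nn.Parameter(lr * torch.ones_like(param))
+
+    def forward(self, grad):
+        return self.lr * grad
+
+metaopt = l2l.optim.LearnableOptimizer(model, transform=PerParamLR, lr=0.1)
+opt = torch.optim.Adam(metaopt.parameters(), lr=3e-4)  # learns the rates by gradient descent
+~~~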
+ +Manually edit the script and run: + +~~~shell +python examples/optimization/hypergrad_mnist.py +~~~ diff --git a/examples/optimization/hypergrad_mnist.py b/examples/optimization/hypergrad_mnist.py new file mode 100644 index 00000000..62d2fb36 --- /dev/null +++ b/examples/optimization/hypergrad_mnist.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 + +""" +File: hypergrad_mnist.py +Author: Seb Arnold - seba1511.net +Email: smr.arnold@gmail.com +Github: seba-1511 +Description: Demonstation of the LearnableOptimizer to optimize a CNN on MNIST. + +While this example is inspired form the hypergradient literature, it differs +from Hypergradient: + 1. We do not use the analytical expression for the hypergradient, but + instead rely on autograd to compute it for us. + 2. We learn a per-parameter learning rate rather than one shared across + all parameters. + +The network is inspired from the official MNIST example, in the PyTorch repo. +""" + +import torch +from torch.nn import functional as F +import torchvision as tv +import learn2learn as l2l +import tqdm + + +def accuracy(predictions, targets): + """Returns mean accuracy over a mini-batch""" + predictions = predictions.argmax(dim=1).view(targets.shape) + return (predictions == targets).sum().float() / targets.size(0) + + +class Net(torch.nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = torch.nn.Conv2d(1, 32, 3, 1) + self.conv2 = torch.nn.Conv2d(32, 64, 3, 1) + self.dropout1 = torch.nn.Dropout2d(0.25) + self.dropout2 = torch.nn.Dropout2d(0.5) + self.fc1 = torch.nn.Linear(9216, 128) + self.fc2 = torch.nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +class HypergradTransform(torch.nn.Module): + """Hypergradient-style per-parameter learning rates""" + + def __init__(self, param, lr=0.01): + super(HypergradTransform, self).__init__() + self.lr = lr * torch.ones_like(param, requires_grad=True) + self.lr = torch.nn.Parameter(self.lr) + + def forward(self, grad): + return self.lr * grad + + +def main(): + torch.manual_seed(1234) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = Net() + model.to(device) + metaopt = l2l.optim.LearnableOptimizer( + model=model, # We pass the model, not its parameters + transform=HypergradTransform, # Any transform could work + lr=0.1) + metaopt.to(device) # metaopt inherits from torch.nn.Module + opt = torch.optim.Adam(metaopt.parameters(), lr=3e-4) + loss = torch.nn.NLLLoss() + + kwargs = {'num_workers': 1, + 'pin_memory': True} if torch.cuda.is_available() else {} + train_loader = torch.utils.data.DataLoader( + tv.datasets.MNIST('~/data', train=True, download=True, + transform=tv.transforms.Compose([ + tv.transforms.ToTensor(), + tv.transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=32, shuffle=True, **kwargs) + test_loader = torch.utils.data.DataLoader( + tv.datasets.MNIST('~/data', train=False, transform=tv.transforms.Compose([ + tv.transforms.ToTensor(), + tv.transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=128, shuffle=False, **kwargs) + + for epoch in range(10): + # Train for an epoch + model.train() + for X, y in tqdm.tqdm(train_loader, leave=False): + X, y = X.to(device), y.to(device) + metaopt.zero_grad() + opt.zero_grad() + err = loss(model(X), y) + 
err.backward() + opt.step() # Update metaopt parameters + metaopt.step() # Update model parameters + + # Compute test error + model.eval() + test_error = 0.0 + test_accuracy = 0.0 + with torch.no_grad(): + for X, y in test_loader: + X, y = X.to(device), y.to(device) + preds = model(X) + test_error += loss(preds, y) + test_accuracy += accuracy(preds, y) + test_error /= len(test_loader) + test_accuracy /= len(test_loader) + print('\nEpoch', epoch) + print('Loss:', test_error.item()) + print('Accuracy:', test_accuracy.item()) + + # Print the learned learning rates of the model + print('The learning rates were:') + for p in metaopt.parameters(): + print(p) + + +if __name__ == '__main__': + main() diff --git a/examples/vision/anilkfo_cifarfs.py b/examples/vision/anilkfo_cifarfs.py new file mode 100644 index 00000000..7955fb32 --- /dev/null +++ b/examples/vision/anilkfo_cifarfs.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 + +""" +File: anilkfo_cifarfs.py +Author: Seb Arnold - seba1511.net +Email: smr.arnold@gmail.com +Github: seba-1511 +Description: +Demonstrates how to use the low-level differentiable optimization utilities +to implement ANIL+KFC on CIFAR-FS. + +A demonstration of the high-level API is available in: + examples/vision/metacurvature_fc100.py +""" + +import random +import numpy as np +import torch +import learn2learn as l2l + + +class CifarCNN(torch.nn.Module): + """ + Example of a 4-layer CNN network for FC100/CIFAR-FS. + """ + + def __init__(self, output_size=5, hidden_size=32, layers=4): + super(CifarCNN, self).__init__() + self.hidden_size = hidden_size + features = l2l.vision.models.ConvBase( + output_size=hidden_size, + hidden=hidden_size, + channels=3, + max_pool=False, + layers=layers, + max_pool_factor=0.5, + ) + self.features = torch.nn.Sequential( + features, + l2l.nn.Lambda(lambda x: x.mean(dim=[2, 3])), + l2l.nn.Flatten(), + ) + self.linear = torch.nn.Linear(self.hidden_size, output_size, bias=True) + l2l.vision.models.maml_init_(self.linear) + + def forward(self, x): + x = self.features(x) + x = self.linear(x) + return x + + +def accuracy(predictions, targets): + predictions = predictions.argmax(dim=1).view(targets.shape) + return (predictions == targets).sum().float() / targets.size(0) + + +def fast_adapt( + batch, + features, + classifier, + update, + diff_sgd, + loss, + adaptation_steps, + shots, + ways, + device): + data, labels = batch + data, labels = data.to(device), labels.to(device) + data = features(data) + + # Separate data into adaptation/evalutation sets + adaptation_indices = np.zeros(data.size(0), dtype=bool) + adaptation_indices[np.arange(shots*ways) * 2] = True + evaluation_indices = torch.from_numpy(~adaptation_indices) + adaptation_indices = torch.from_numpy(adaptation_indices) + adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices] + evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices] + + # Adapt the model & learned update + for step in range(adaptation_steps): + adaptation_error = loss(classifier(adaptation_data), adaptation_labels) + if step > 0: # Update the learnable update function + update_grad = torch.autograd.grad(adaptation_error, + update.parameters(), + create_graph=True, + retain_graph=True) + diff_sgd(update, update_grad) + classifier_updates = update(adaptation_error, + classifier.parameters(), + create_graph=True, + retain_graph=True) + diff_sgd(classifier, classifier_updates) + + # Evaluate the adapted model + predictions = classifier(evaluation_data) + 
eval_error = loss(predictions, evaluation_labels) + eval_accuracy = accuracy(predictions, evaluation_labels) + return eval_error, eval_accuracy + + +def main( + fast_lr=0.1, + meta_lr=0.003, + num_iterations=10000, + meta_batch_size=16, + adaptation_steps=5, + shots=5, + ways=5, + cuda=1, + seed=1234 + ): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + device = torch.device('cpu') + if cuda and torch.cuda.device_count(): + torch.cuda.manual_seed(seed) + device = torch.device('cuda') + + # Create Tasksets using the benchmark interface + tasksets = l2l.vision.benchmarks.get_tasksets( + name='cifarfs', + train_samples=2*shots, + train_ways=ways, + test_samples=2*shots, + test_ways=ways, + root='~/data', + ) + + # Create model and learnable update + model = CifarCNN(output_size=ways) + model.to(device) + features = model.features + classifier = model.linear + kfo_transform = l2l.optim.transforms.KroneckerTransform(l2l.nn.KroneckerLinear) + fast_update = l2l.optim.ParameterUpdate( + parameters=classifier.parameters(), + transform=kfo_transform, + ) + fast_update.to(device) + diff_sgd = l2l.optim.DifferentiableSGD(lr=fast_lr) + + all_parameters = list(model.parameters()) + list(fast_update.parameters()) + opt = torch.optim.Adam(all_parameters, meta_lr) + loss = torch.nn.CrossEntropyLoss(reduction='mean') + + for iteration in range(num_iterations): + opt.zero_grad() + meta_train_error = 0.0 + meta_train_accuracy = 0.0 + meta_valid_error = 0.0 + meta_valid_accuracy = 0.0 + for task in range(meta_batch_size): + # Compute meta-training loss + task_features = l2l.clone_module(features) + task_classifier = l2l.clone_module(classifier) + task_update = l2l.clone_module(fast_update) + batch = tasksets.train.sample() + evaluation_error, evaluation_accuracy = fast_adapt(batch, + task_features, + task_classifier, + task_update, + diff_sgd, + loss, + adaptation_steps, + shots, + ways, + device) + evaluation_error.backward() + meta_train_error += evaluation_error.item() + meta_train_accuracy += evaluation_accuracy.item() + + # Compute meta-validation loss + task_features = l2l.clone_module(features) + task_classifier = l2l.clone_module(classifier) + task_update = l2l.clone_module(fast_update) + batch = tasksets.validation.sample() + evaluation_error, evaluation_accuracy = fast_adapt(batch, + task_features, + task_classifier, + task_update, + diff_sgd, + loss, + adaptation_steps, + shots, + ways, + device) + meta_valid_error += evaluation_error.item() + meta_valid_accuracy += evaluation_accuracy.item() + + # Print some metrics + print('\n') + print('Iteration', iteration) + print('Meta Train Error', meta_train_error / meta_batch_size) + print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size) + print('Meta Valid Error', meta_valid_error / meta_batch_size) + print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size) + + # Average the accumulated gradients and optimize + for p in model.parameters(): + p.grad.data.mul_(1.0 / meta_batch_size) + for p in fast_update.parameters(): + p.grad.data.mul_(1.0 / meta_batch_size) + opt.step() + + meta_test_error = 0.0 + meta_test_accuracy = 0.0 + for task in range(meta_batch_size): + # Compute meta-testing loss + task_features = l2l.clone_module(features) + task_classifier = l2l.clone_module(classifier) + task_update = l2l.clone_module(fast_update) + batch = tasksets.test.sample() + evaluation_error, evaluation_accuracy = fast_adapt(batch, + task_features, + task_classifier, + task_update, + diff_sgd, + loss, + adaptation_steps, + 
shots, + ways, + device) + meta_test_error += evaluation_error.item() + meta_test_accuracy += evaluation_accuracy.item() + print('Meta Test Error', meta_test_error / meta_batch_size) + print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size) + + +if __name__ == '__main__': + main() diff --git a/examples/vision/maml_miniimagenet.py b/examples/vision/maml_miniimagenet.py index 81cbed5b..c57fd8d1 100644 --- a/examples/vision/maml_miniimagenet.py +++ b/examples/vision/maml_miniimagenet.py @@ -42,16 +42,16 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device): # Adapt the model for step in range(adaptation_steps): - train_error = loss(learner(adaptation_data), adaptation_labels) - train_error /= len(adaptation_data) - learner.adapt(train_error) + adaptation_error = loss(learner(adaptation_data), adaptation_labels) + adaptation_error /= len(adaptation_data) + learner.adapt(adaptation_error) # Evaluate the adapted model predictions = learner(evaluation_data) - valid_error = loss(predictions, evaluation_labels) - valid_error /= len(evaluation_data) - valid_accuracy = accuracy(predictions, evaluation_labels) - return valid_error, valid_accuracy + evaluation_error = loss(predictions, evaluation_labels) + evaluation_error /= len(evaluation_data) + evaluation_accuracy = accuracy(predictions, evaluation_labels) + return evaluation_error, evaluation_accuracy def main( diff --git a/examples/vision/metacurvature_fc100.py b/examples/vision/metacurvature_fc100.py new file mode 100644 index 00000000..4b9ed1f6 --- /dev/null +++ b/examples/vision/metacurvature_fc100.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 + +""" +File: metacurvature_fc100.py +Author: Seb Arnold - seba1511.net +Email: smr.arnold@gmail.com +Github: seba-1511 +Description: +Demonstrates how to use the GBML wrapper to implement MetaCurvature. + +A demonstration of the low-level API is available in: + examples/vision/anilkfo_cifarfs.py +""" + +import random +import numpy as np +import torch +import learn2learn as l2l +from learn2learn.optim.transforms import MetaCurvatureTransform + + +class CifarCNN(torch.nn.Module): + """ + Example of a 4-layer CNN network for FC100/CIFAR-FS. 
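+    (This is the same backbone used in examples/vision/anilkfo_cifarfs.py.)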
+ """ + + def __init__(self, output_size=5, hidden_size=32, layers=4): + super(CifarCNN, self).__init__() + self.hidden_size = hidden_size + features = l2l.vision.models.ConvBase( + output_size=hidden_size, + hidden=hidden_size, + channels=3, + max_pool=False, + layers=layers, + max_pool_factor=0.5, + ) + self.features = torch.nn.Sequential( + features, + l2l.nn.Lambda(lambda x: x.mean(dim=[2, 3])), + l2l.nn.Flatten(), + ) + self.linear = torch.nn.Linear(self.hidden_size, output_size, bias=True) + l2l.vision.models.maml_init_(self.linear) + + def forward(self, x): + x = self.features(x) + x = self.linear(x) + return x + + +def accuracy(predictions, targets): + predictions = predictions.argmax(dim=1).view(targets.shape) + return (predictions == targets).sum().float() / targets.size(0) + + +def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device): + data, labels = batch + data, labels = data.to(device), labels.to(device) + + # Separate data into adaptation/evalutation sets + adaptation_indices = np.zeros(data.size(0), dtype=bool) + adaptation_indices[np.arange(shots*ways) * 2] = True + evaluation_indices = torch.from_numpy(~adaptation_indices) + adaptation_indices = torch.from_numpy(adaptation_indices) + adaptation_data, adaptation_labels = data[adaptation_indices], labels[adaptation_indices] + evaluation_data, evaluation_labels = data[evaluation_indices], labels[evaluation_indices] + + # Adapt the model + for step in range(adaptation_steps): + adaptation_error = loss(learner(adaptation_data), adaptation_labels) + learner.adapt(adaptation_error) + + # Evaluate the adapted model + predictions = learner(evaluation_data) + evaluation_error = loss(predictions, evaluation_labels) + evaluation_accuracy = accuracy(predictions, evaluation_labels) + return evaluation_error, evaluation_accuracy + + +def main( + fast_lr=0.1, + meta_lr=0.01, + num_iterations=10000, + meta_batch_size=16, + adaptation_steps=5, + shots=5, + ways=5, + cuda=1, + seed=1234 + ): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + device = torch.device('cpu') + if cuda and torch.cuda.device_count(): + torch.cuda.manual_seed(seed) + device = torch.device('cuda') + + # Create Tasksets using the benchmark interface + tasksets = l2l.vision.benchmarks.get_tasksets( + name='fc100', + train_samples=2*shots, + train_ways=ways, + test_samples=2*shots, + test_ways=ways, + root='~/data', + ) + + # Create model + model = CifarCNN(output_size=ways) + model.to(device) + gbml = l2l.algorithms.GBML( + model, + transform=MetaCurvatureTransform, + lr=fast_lr, + adapt_transform=False, + ) + gbml.to(device) + opt = torch.optim.Adam(gbml.parameters(), meta_lr) + loss = torch.nn.CrossEntropyLoss(reduction='mean') + + for iteration in range(num_iterations): + opt.zero_grad() + meta_train_error = 0.0 + meta_train_accuracy = 0.0 + meta_valid_error = 0.0 + meta_valid_accuracy = 0.0 + for task in range(meta_batch_size): + # Compute meta-training loss + learner = gbml.clone() + batch = tasksets.train.sample() + evaluation_error, evaluation_accuracy = fast_adapt(batch, + learner, + loss, + adaptation_steps, + shots, + ways, + device) + evaluation_error.backward() + meta_train_error += evaluation_error.item() + meta_train_accuracy += evaluation_accuracy.item() + + # Compute meta-validation loss + learner = gbml.clone() + batch = tasksets.validation.sample() + evaluation_error, evaluation_accuracy = fast_adapt(batch, + learner, + loss, + adaptation_steps, + shots, + ways, + device) + meta_valid_error += 
evaluation_error.item() + meta_valid_accuracy += evaluation_accuracy.item() + + # Print some metrics + print('\n') + print('Iteration', iteration) + print('Meta Train Error', meta_train_error / meta_batch_size) + print('Meta Train Accuracy', meta_train_accuracy / meta_batch_size) + print('Meta Valid Error', meta_valid_error / meta_batch_size) + print('Meta Valid Accuracy', meta_valid_accuracy / meta_batch_size) + + # Average the accumulated gradients and optimize + for p in gbml.parameters(): + p.grad.data.mul_(1.0 / meta_batch_size) + opt.step() + + meta_test_error = 0.0 + meta_test_accuracy = 0.0 + for task in range(meta_batch_size): + # Compute meta-testing loss + learner = gbml.clone() + batch = tasksets.test.sample() + evaluation_error, evaluation_accuracy = fast_adapt(batch, + learner, + loss, + adaptation_steps, + shots, + ways, + device) + meta_test_error += evaluation_error.item() + meta_test_accuracy += evaluation_accuracy.item() + print('Meta Test Error', meta_test_error / meta_batch_size) + print('Meta Test Accuracy', meta_test_accuracy / meta_batch_size) + + +if __name__ == '__main__': + main() diff --git a/learn2learn/__init__.py b/learn2learn/__init__.py index 39797d5e..d503ff28 100644 --- a/learn2learn/__init__.py +++ b/learn2learn/__init__.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 +from ._version import __version__ from . import algorithms from . import data from . import gym from . import text from . import vision -from ._version import __version__ +from . import optim +from . import nn from .utils import * diff --git a/learn2learn/_version.py b/learn2learn/_version.py index df9144c5..10939f01 100644 --- a/learn2learn/_version.py +++ b/learn2learn/_version.py @@ -1 +1 @@ -__version__ = '0.1.1' +__version__ = '0.1.2' diff --git a/learn2learn/algorithms/__init__.py b/learn2learn/algorithms/__init__.py index 7ddccb4e..27b3eee3 100644 --- a/learn2learn/algorithms/__init__.py +++ b/learn2learn/algorithms/__init__.py @@ -1,4 +1,9 @@ #!/usr/bin/env python3 +r""" +A set of high-level algorithm implementations, with easy-to-use API. +""" + from .maml import MAML, maml_update from .meta_sgd import MetaSGD, meta_sgd_update +from .gbml import GBML diff --git a/learn2learn/algorithms/gbml.py b/learn2learn/algorithms/gbml.py new file mode 100644 index 00000000..24d230f1 --- /dev/null +++ b/learn2learn/algorithms/gbml.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 + +import torch +import learn2learn as l2l + + +class GBML(torch.nn.Module): + """ + + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/gbml.py) + + **Description** + + General wrapper for gradient-based meta-learning implementations. + + A variety of algorithms can simply be implemented by changing the kind + of `transform` used during fast-adaptation. + For example, if the transform is `Scale` we recover Meta-SGD [2] with `adapt_transform=False` + and Alpha MAML [4] with `adapt_transform=True`. + If the transform is a Kronecker-factored module (e.g. neural network, or linear), we recover + KFO from [5]. + + **Arguments** + + * **module** (Module) - Module to be wrapped. + * **tranform** (Module) - Transform used to update the module. + * **lr** (float) - Fast adaptation learning rate. + * **adapt_transform** (bool, *optional*, default=False) - Whether to update the transform's + parameters during fast-adaptation. + * **first_order** (bool, *optional*, default=False) - Whether to use the first-order + approximation. 
+ * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation + of unused parameters. Defaults to `allow_nograd`. + * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with + parameters that have `requires_grad = False`. + + **References** + + 1. Finn et al. 2017. “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.” + 2. Li et al. 2017. “Meta-SGD: Learning to Learn Quickly for Few-Shot Learning.” + 3. Park & Oliva. 2019. “Meta-Curvature.” + 4. Behl et al. 2019. “Alpha MAML: Adaptive Model-Agnostic Meta-Learning.” + 5. Arnold et al. 2019. “When MAML Can Adapt Fast and How to Assist When It Cannot.” + + **Example** + + ~~~python + model = SmallCNN() + transform = l2l.optim.ModuleTransform(torch.nn.Linear) + gbml = l2l.algorithms.GBML( + module=model, + transform=transform, + lr=0.01, + adapt_transform=True, + ) + gbml.to(device) + opt = torch.optim.SGD(gbml.parameters(), lr=0.001) + + # Training with 1 adaptation step + for iteration in range(10): + opt.zero_grad() + task_model = gbml.clone() + loss = compute_loss(task_model) + task_model.adapt(loss) + loss.backward() + opt.step() + ~~~ + """ + + def __init__( + self, + module, + transform, + lr=1.0, + adapt_transform=False, + first_order=False, + allow_unused=False, + allow_nograd=False, + **kwargs, + ): + super(GBML, self).__init__() + self.module = module + self.transform = transform + self.adapt_transform = adapt_transform + self.lr = lr + self.first_order = first_order + self.allow_unused = allow_unused + self.allow_nograd = allow_nograd + if 'compute_update' in kwargs: + self.compute_update = kwargs.get('compute_update') + else: + self.compute_update = l2l.optim.ParameterUpdate( + parameters=self.module.parameters(), + transform=transform, + ) + self.diff_sgd = l2l.optim.DifferentiableSGD(lr=self.lr) + # Whether the module params have already been updated with the + # updates from compute_update. Used to keep track of whether we + # can compute the gradient of compute_update's parameters. + self._params_updated = False + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + def clone( + self, + first_order=None, + allow_unused=None, + allow_nograd=None, + adapt_transform=None, + ): + """ + **Description** + + Similar to `MAML.clone()`. + + **Arguments** + + * **first_order** (bool, *optional*, default=None) - Whether the clone uses first- + or second-order updates. Defaults to self.first_order. + * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation + of unused parameters. Defaults to self.allow_unused. + * **allow_nograd** (bool, *optional*, default=False) - Whether to allow adaptation with + parameters that have `requires_grad = False`. Defaults to self.allow_nograd. 
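+        * **adapt_transform** (bool, *optional*, default=None) - Whether the clone should update
+          the transform's parameters during fast-adaptation. Defaults to self.adapt_transform.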
+ + """ + if first_order is None: + first_order = self.first_order + if allow_unused is None: + allow_unused = self.allow_unused + if allow_nograd is None: + allow_nograd = self.allow_nograd + if adapt_transform is None: + adapt_transform = self.adapt_transform + module_clone = l2l.clone_module(self.module) + update_clone = l2l.clone_module(self.compute_update) + return GBML( + module=module_clone, + transform=self.transform, + lr=self.lr, + adapt_transform=adapt_transform, + first_order=first_order, + allow_unused=allow_unused, + allow_nograd=allow_nograd, + compute_update=update_clone, + ) + + def adapt( + self, + loss, + first_order=None, + allow_nograd=None, + allow_unused=None, + ): + """ + **Description** + + Takes a gradient step on the loss and updates the cloned parameters in place. + + The parameters of the transform are only adapted if `self.adapt_update` is `True`. + + **Arguments** + + * **loss** (Tensor) - Loss to minimize upon update. + * **first_order** (bool, *optional*, default=None) - Whether to use first- or + second-order updates. Defaults to self.first_order. + * **allow_unused** (bool, *optional*, default=None) - Whether to allow differentiation + of unused parameters. Defaults to self.allow_unused. + * **allow_nograd** (bool, *optional*, default=None) - Whether to allow adaptation with + parameters that have `requires_grad = False`. Defaults to self.allow_nograd. + """ + if first_order is None: + first_order = self.first_order + if allow_unused is None: + allow_unused = self.allow_unused + if allow_nograd is None: + allow_nograd = self.allow_nograd + second_order = not first_order + + if self.adapt_transform and self._params_updated: + # Update the learnable update function + update_grad = torch.autograd.grad( + loss, + self.compute_update.parameters(), + create_graph=second_order, + retain_graph=second_order, + allow_unused=allow_unused, + ) + self.diff_sgd(self.compute_update, update_grad) + self._params_updated = False + + # Update the module + updates = self.compute_update( + loss, + self.module.parameters(), + create_graph=second_order or self.adapt_transform, + retain_graph=second_order or self.adapt_transform, + allow_unused=allow_unused, + allow_nograd=allow_nograd, + ) + self.diff_sgd(self.module, updates) + self._params_updated = True diff --git a/learn2learn/algorithms/maml.py b/learn2learn/algorithms/maml.py index f4094fa7..08fa2638 100644 --- a/learn2learn/algorithms/maml.py +++ b/learn2learn/algorithms/maml.py @@ -4,7 +4,7 @@ from torch.autograd import grad from learn2learn.algorithms.base_learner import BaseLearner -from learn2learn.utils import clone_module +from learn2learn.utils import clone_module, update_module def maml_update(model, lr, grads=None): @@ -42,36 +42,13 @@ def maml_update(model, lr, grads=None): msg += str(len(params)) + ' vs ' + str(len(grads)) + ')' print(msg) for p, g in zip(params, grads): - p.grad = g - - # Update the params - for param_key in model._parameters: - p = model._parameters[param_key] - if p is not None and p.grad is not None: - model._parameters[param_key] = p - lr * p.grad - - # Second, handle the buffers if necessary - for buffer_key in model._buffers: - buff = model._buffers[buffer_key] - if buff is not None and buff.grad is not None: - model._buffers[buffer_key] = buff - lr * buff.grad - - # Then, recurse for each submodule - for module_key in model._modules: - model._modules[module_key] = maml_update(model._modules[module_key], - lr=lr, - grads=None) - - # Finally, rebuild the flattened parameters for RNNs - # 
See this issue for more details: - # https://github.com/learnables/learn2learn/issues/139 - model._apply(lambda x: x) - return model + if g is not None: + p.update = - lr * g + return update_module(model) class MAML(BaseLearner): """ - [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/algorithms/maml.py) **Description** diff --git a/learn2learn/data/__init__.py b/learn2learn/data/__init__.py index b34d6207..00ce59ec 100644 --- a/learn2learn/data/__init__.py +++ b/learn2learn/data/__init__.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 +r""" +A set of utilities for data & tasks loading, preprocessing, and sampling. +""" + from . import transforms from .meta_dataset import MetaDataset from .task_dataset import TaskDataset, DataDescription diff --git a/learn2learn/gym/__init__.py b/learn2learn/gym/__init__.py index baee5f6e..1612e3d8 100644 --- a/learn2learn/gym/__init__.py +++ b/learn2learn/gym/__init__.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 +r""" +Environment, models, and other utilities related to reinforcement learning and OpenAI Gym. +""" + from . import envs from .envs.meta_env import MetaEnv from .async_vec_env import AsyncVectorEnv diff --git a/learn2learn/gym/envs/metaworld/metaworld.py b/learn2learn/gym/envs/metaworld/metaworld.py index 1e6af9a2..ccf6f4cd 100644 --- a/learn2learn/gym/envs/metaworld/metaworld.py +++ b/learn2learn/gym/envs/metaworld/metaworld.py @@ -5,9 +5,18 @@ try: from metaworld.envs.mujoco.multitask_env import MultiClassMultiTaskEnv from metaworld.benchmarks import ML1, ML10, ML45 -except DependencyNotInstalled: +except (DependencyNotInstalled, ModuleNotFoundError): from learn2learn.gym.envs.mujoco.dummy_mujoco_env import MujocoEnv as MultiClassMultiTaskEnv + class ML1: + pass + + class ML10: + pass + + class ML45: + pass + from learn2learn.gym.envs.meta_env import MetaEnv diff --git a/learn2learn/nn/__init__.py b/learn2learn/nn/__init__.py new file mode 100644 index 00000000..512b354b --- /dev/null +++ b/learn2learn/nn/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 + +r""" +Additional `torch.nn.Module`s frequently used for meta-learning. +""" + +from .kroneckers import * +from .misc import * diff --git a/learn2learn/nn/kroneckers.py b/learn2learn/nn/kroneckers.py new file mode 100644 index 00000000..e437cf32 --- /dev/null +++ b/learn2learn/nn/kroneckers.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 + +import torch as th +from torch import nn + + +def kronecker_addmm(mat1, mat2, mat3, bias=None, alpha=1.0, beta=1.0): + """ + Returns alpha * (mat2.t() X mat1) @ vec(mat3) + beta * vec(bias) + (Assuming bias is not None.) + """ + res = mat1 @ mat3 @ mat2 + res.mul_(alpha) + if bias is not None: + res.add_(beta, bias) + return res + + +class KroneckerLinear(nn.Module): + + r""" + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/kroneckers.py) + + **Description** + + A linear transformation whose parameters are expressed as a Kronecker product. + + This Module maps an input vector \(x \in \mathbb{R}^{nm} \) to \(y = Ax + b\) such that: + + \[ + A = R^\top \otimes L, + \] + + where \(L \in \mathbb{R}^{n \times n}\) and \(R \in \mathbb{R}^{m \times m}\) are the learnable Kronecker factors. + This implementation can reduce the memory requirement for large linear mapping + from \(\mathcal{O}(n^2 \cdot m^2)\) to \(\mathcal{O}(n^2 + m^2)\), but forces \(y \in \mathbb{R}^{nm}\). + + The matrix \(A\) is initialized as the identity, and the bias as a zero vector. 
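+
+    Note that \(A\) is never materialized: writing the input as an \(n \times m\) matrix \(X\),
+    the identity \((R^\top \otimes L) \operatorname{vec}(X) = \operatorname{vec}(L X R)\) lets the
+    forward pass multiply by the two small factors only.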
+ + **Arguments** + + * **n** (int) - Dimensionality of the left Kronecker factor. + * **m** (int) - Dimensionality of the right Kronecker factor. + * **bias** (bool, *optional*, default=True) - Whether to include the bias term. + * **psd** (bool, *optional*, default=False) - Forces the matrix \(A\) to be positive semi-definite if True. + * **device** (device, *optional*, default=None) - The device on which to instantiate the Module. + + **References** + + 1. Jose et al. 2018. "Kronecker recurrent units". + 2. Arnold et al. 2019. "When MAML can adapt fast and how to assist when it cannot". + + **Example** + ~~~python + m, n = 2, 3 + x = torch.randn(6) + kronecker = KroneckerLinear(n, m) + y = kronecker(x) + y.shape # (6, ) + ~~~ + """ + + def __init__(self, n, m, bias=True, psd=False, device=None): + super(KroneckerLinear, self).__init__() + self.left = nn.Parameter(th.eye(n, device=device)) + self.right = nn.Parameter(th.eye(m, device=device)) + self.bias = None + self.psd = psd + if bias: + self.bias = nn.Parameter(th.zeros(n, m, device=device)) + self.device = device + self.to(device=device) + + def forward(self, x): + old_device = x.device + if self.device is not None: + x = x.to(self.device) + left = self.left + right = self.right + if self.psd: + left = left.t() @ left + right = right.t() @ right + if len(x.shape) == 1: + shape = x.shape + x = x.view(-1, 1) + x = kronecker_addmm(left, right, x, self.bias) + return x.view(*shape).to(old_device) + x = kronecker_addmm(left, right, x, self.bias) + return x.to(old_device) + + +class KroneckerRNN(nn.Module): + + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/kroneckers.py) + + **Description** + + Implements a recurrent neural network whose matrices are parameterized via their Kronecker factors. + (See `KroneckerLinear` for details.) + + **Arguments** + + * **n** (int) - Dimensionality of the left Kronecker factor. + * **m** (int) - Dimensionality of the right Kronecker factor. + * **bias** (bool, *optional*, default=True) - Whether to include the bias term. + * **sigma** (callable, *optional*, default=None) - The activation function. + + **References** + + 1. Jose et al. 2018. "Kronecker recurrent units". + + **Example** + ~~~python + m, n = 2, 3 + x = torch.randn(6) + h = torch.randn(6) + kronecker = KroneckerRNN(n, m) + y, new_h = kronecker(x, h) + y.shape # (6, ) + ~~~ + """ + + def __init__(self, n, m, bias=True, sigma=None): + super(KroneckerRNN, self).__init__() + self.W_h = KroneckerLinear(n, m, bias=bias) + self.U_h = KroneckerLinear(n, m, bias=bias) + self.W_y = KroneckerLinear(n, m, bias=bias) + + if sigma is None: + sigma = nn.Tanh() + self.sigma = sigma + + def forward(self, x, hidden): + new_hidden = self.W_h(x) + self.U_h(hidden) + new_hidden = self.sigma(new_hidden) + output = self.W_y(new_hidden) + return output, new_hidden + + +class KroneckerLSTM(nn.Module): + + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/kroneckers.py) + + **Description** + + Implements an LSTM using a factorization similar to the one of + `KroneckerLinear`. + + **Arguments** + + * **n** (int) - Dimensionality of the left Kronecker factor. + * **m** (int) - Dimensionality of the right Kronecker factor. + * **bias** (bool, *optional*, default=True) - Whether to include the bias term. + * **sigma** (callable, *optional*, default=None) - The activation function. + + **References** + + 1. Jose et al. 2018. "Kronecker recurrent units". 
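+
+    The recurrence computed is \(h_t = \sigma(W_h x_t + U_h h_{t-1})\) with output
+    \(y_t = W_y h_t\), where each weight matrix is parameterized as a `KroneckerLinear`.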
+ + **Example** + ~~~python + m, n = 2, 3 + x = torch.randn(6) + h = torch.randn(6) + kronecker = KroneckerLSTM(n, m) + y, new_h = kronecker(x, h) + y.shape # (6, ) + ~~~ + """ + + def __init__(self, n, m, bias=True, sigma=None): + super(KroneckerLSTM, self).__init__() + self.W_ii = KroneckerLinear(n, m, bias=bias) + self.W_hi = KroneckerLinear(n, m, bias=bias) + self.W_if = KroneckerLinear(n, m, bias=bias) + self.W_hf = KroneckerLinear(n, m, bias=bias) + self.W_ig = KroneckerLinear(n, m, bias=bias) + self.W_hg = KroneckerLinear(n, m, bias=bias) + self.W_io = KroneckerLinear(n, m, bias=bias) + self.W_ho = KroneckerLinear(n, m, bias=bias) + if sigma is None: + sigma = nn.Sigmoid() + self.sigma = sigma + self.tanh = nn.Tanh() + + def forward(self, x, hidden): + h, c = hidden + i = self.sigma(self.W_ii(x) + self.W_hi(h)) + f = self.sigma(self.W_if(x) + self.W_hf(h)) + g = self.tanh(self.W_ig(x) + self.W_hg(h)) + o = self.sigma(self.W_io(x) + self.W_ho(h)) + c = f * c + i * g + h = o * self.tanh(c) + return h, (h, c) diff --git a/learn2learn/nn/misc.py b/learn2learn/nn/misc.py new file mode 100644 index 00000000..f23344f3 --- /dev/null +++ b/learn2learn/nn/misc.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 + +import torch + + +class Lambda(torch.nn.Module): + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/misc.py) + + **Description** + + Utility class to create a wrapper based on a lambda function. + + **Arguments** + + * **lmb** (callable) - The function to call in the forward pass. + + **Example** + ~~~python + mean23 = Lambda(lambda x: x.mean(dim=[2, 3])) # mean23 is a Module + x = features(img) + x = mean23(x) + x = x.flatten() + ~~~ + """ + + def __init__(self, lmb): + super(Lambda, self).__init__() + self.lmb = lmb + + def forward(self, *args, **kwargs): + return self.lmb(*args, **kwargs) + + +class Flatten(torch.nn.Module): + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/misc.py) + + **Description** + + Utility Module to flatten inputs to `(batch_size, -1)` shape. + + **Example** + ~~~python + flatten = Flatten() + x = torch.randn(5, 3, 32, 32) + x = flatten(x) + print(x.shape) # (5, 3072) + ~~~ + """ + + def forward(self, x): + return x.view(x.size(0), -1) + + +class Scale(torch.nn.Module): + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/nn/misc.py) + + **Description** + + A per-parameter scaling factor with learnable parameter. + + **Arguments** + + * **shape** (int or tuple) - The shape of the scaling matrix. + * **alpha** (float, *optional*, default=1.0) - Initial value for the + scaling factor. + + **Example** + ~~~python + x = torch.ones(3) + scale = Scale(x.shape, alpha=0.5) + print(scale(x)) # [.5, .5, .5] + ~~~ + """ + + def __init__(self, shape, alpha=1.0): + super(Scale, self).__init__() + if isinstance(shape, int): + shape = (shape, ) + alpha = torch.ones(**shape) + self.alpha = torch.nn.Parameter(alpha) + + def forward(self, x): + return x * self.alpha diff --git a/learn2learn/optim/__init__.py b/learn2learn/optim/__init__.py new file mode 100644 index 00000000..48708989 --- /dev/null +++ b/learn2learn/optim/__init__.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +""" +A set of utilities to write differentiable optimization algorithms. +""" + +from .parameter_update import ParameterUpdate +from .learnable_optimizer import LearnableOptimizer +from .update_rules import * +from . 
import transforms diff --git a/learn2learn/optim/learnable_optimizer.py b/learn2learn/optim/learnable_optimizer.py new file mode 100644 index 00000000..1875dfee --- /dev/null +++ b/learn2learn/optim/learnable_optimizer.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +import torch +import learn2learn as l2l +import warnings + + +class LearnableOptimizer(torch.nn.Module): + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/learnable_optimizer.py) + + **Description** + + A PyTorch Optimizer with learnable transform, enabling the implementation + of meta-descent / hyper-gradient algorithms. + + This optimizer takes a Module and a gradient transform. + At each step, the gradient of the module is passed through the transforms, + and the module differentiably update -- i.e. when the next backward is called, + gradients of both the module and the transform are computed. + In turn, the transform can be updated via your favorite optmizer. + + **Arguments** + + * **model** (Module) - Module to be updated. + * **transform** (Module) - Transform used to compute updates of the model. + * **lr** (float) - Learning rate. + + **References** + + 1. Sutton. 1992. “Gain Adaptation Beats Least Squares.” + 2. Schraudolph. 1999. “Local Gain Adaptation in Stochastic Gradient Descent.” + 3. Baydin et al. 2017. “Online Learning Rate Adaptation with Hypergradient Descent.” + 4. Majumder et al. 2019. “Learning the Learning Rate for Gradient Descent by Gradient Descent.” + 5. Jacobsen et al. 2019. “Meta-Descent for Online, Continual Prediction.” + + **Example** + + ~~~python + linear = nn.Linear(784, 10) + transform = l2l.optim.ModuleTransform(torch.nn.Linear) + metaopt = l2l.optim.LearnableOptimizer(linear, transform, lr=0.01) + opt = torch.optim.SGD(metaopt.parameters(), lr=0.001) + + metaopt.zero_grad() + opt.zero_grad() + error = loss(linear(X), y) + error.backward() + opt.step() # update metaopt + metaopt.step() # update linear + ~~~ + """ + + def __init__(self, model, transform, lr=1.0): + super(LearnableOptimizer, self).__init__() + assert isinstance(model, torch.nn.Module), \ + 'model should inherit from nn.Module.' + + # Keep pointer to model, but don't include in self._modules, + # self._children, or self._parameters + self.info = { + 'model': model, + } + + # Create the transforms + self.transforms = [] + for name, param in model.named_parameters(): + trans = transform(param) + self.transforms.append(trans) + self.transforms = torch.nn.ModuleList(self.transforms) + self.lr = lr + + def step(self, closure=None): + # TODO: Do we need to recompute flat_grads for RNNs ? - Write a test. + model = self.info['model'] + # Ignore warnings as torch 1.5+ warns about accessing .grad of non-leaf + # variables. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for param, transform in zip(model.parameters(), + self.transforms): + if hasattr(param, 'grad') and param.grad is not None: + # 1. compute update + grad = param.grad.detach() + grad.requires_grad = False + update = - self.lr * transform(grad) + + # 2. detach parameters + param.detach_() + param.requires_grad = False + param.update = update + + # 3. apply update so that it's differentiable + l2l.update_module(model, updates=None) + + for param in model.parameters(): + # 4. 
retain grad for the next update
+                param.retain_grad()
+
+    def zero_grad(self):
+        """Only reset target parameters."""
+        model = self.info['model']
+        for p in model.parameters():
+            if hasattr(p, 'grad') and p.grad is not None:
+                # Do not reset in-place:
+                # it breaks the computation graph of step().
+                p.grad = torch.zeros_like(p.data)
diff --git a/learn2learn/optim/parameter_update.py b/learn2learn/optim/parameter_update.py
new file mode 100644
index 00000000..3ad46e83
--- /dev/null
+++ b/learn2learn/optim/parameter_update.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+
+import torch
+import traceback
+
+
+class ParameterUpdate(torch.nn.Module):
+    """
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/parameter_update.py)
+
+    **Description**
+
+    Convenience class to implement custom update functions.
+
+    Objects instantiated from this class behave similarly to `torch.autograd.grad`,
+    but return parameter updates as opposed to gradients.
+    Concretely, the gradients are first computed, then fed to their respective transform,
+    whose output is finally returned to the user.
+
+    Additionally, this class supports parameters that might not require updates by setting
+    the `allow_nograd` flag to True.
+    In this case, the returned update is `None`.
+
+    **Arguments**
+
+    * **parameters** (list) - Parameters of the model to update.
+    * **transform** (callable) - A callable that returns an instantiated
+        transform given a parameter.
+
+    **Example**
+    ~~~python
+    model = torch.nn.Linear(784, 10)
+    transform = l2l.optim.KroneckerTransform(l2l.nn.KroneckerLinear)
+    get_update = ParameterUpdate(model.parameters(), transform)
+    opt = torch.optim.SGD(list(model.parameters()) + list(get_update.parameters()), lr=0.1)
+
+    for iteration in range(10):
+        opt.zero_grad()
+        error = loss(model(X), y)
+        updates = get_update(
+            error,
+            model.parameters(),
+            create_graph=True,
+        )
+        l2l.update_module(model, updates)
+        opt.step()
+    ~~~
+    """
+
+    def __init__(self, parameters, transform):
+        super(ParameterUpdate, self).__init__()
+        transforms_indices = []
+        transform_modules = []
+        module_counter = 0
+        for param in parameters:
+            t = transform(param)
+            if t is None:
+                idx = None
+            elif isinstance(t, torch.nn.Module):
+                transform_modules.append(t)
+                idx = module_counter
+                module_counter += 1
+            else:
+                msg = 'Transform should be either a Module or None.'
+                raise ValueError(msg)
+            transforms_indices.append(idx)
+        self.transforms_modules = torch.nn.ModuleList(transform_modules)
+        self.transforms_indices = transforms_indices
+
+    def forward(
+        self,
+        loss,
+        parameters,
+        create_graph=False,
+        retain_graph=False,
+        allow_unused=False,
+        allow_nograd=False,
+    ):
+        """
+        **Description**
+
+        Similar to torch.autograd.grad, but passes the gradients through the
+        provided transform.
+
+        **Arguments**
+
+        * **loss** (Tensor) - The loss to differentiate.
+        * **parameters** (iterable) - Parameters w.r.t. which we want to compute the update.
+        * **create_graph** (bool, *optional*, default=False) - Same as `torch.autograd.grad`.
+        * **retain_graph** (bool, *optional*, default=False) - Same as `torch.autograd.grad`.
+        * **allow_unused** (bool, *optional*, default=False) - Same as `torch.autograd.grad`.
+        * **allow_nograd** (bool, *optional*, default=False) - Properly handles parameters
+            that do not require gradients. (Their update will be `None`.)
+ + """ + updates = [] + if allow_nograd: + parameters = list(parameters) + diff_params = [p for p in parameters if p.requires_grad] + grad_params = torch.autograd.grad( + loss, + diff_params, + retain_graph=create_graph, + create_graph=create_graph, + allow_unused=allow_unused) + gradients = [] + + # Handles gradients for non-differentiable parameters + grad_counter = 0 + for param in parameters: + if param.requires_grad: + gradient = grad_params[grad_counter] + grad_counter += 1 + else: + gradient = None + gradients.append(gradient) + else: + try: + gradients = torch.autograd.grad( + loss, + parameters, + create_graph=create_graph, + retain_graph=retain_graph, + allow_unused=allow_unused, + ) + except RuntimeError: + traceback.print_exc() + msg = 'learn2learn: Maybe try with allow_nograd=True and/or' +\ + 'allow_unused=True ?' + print(msg) + for g, t in zip(gradients, self.transforms_indices): + if t is None or g is None: + update = g + else: + transform = self.transforms_modules[t] + update = transform(g) + updates.append(update) + return updates diff --git a/learn2learn/optim/transforms/__init__.py b/learn2learn/optim/transforms/__init__.py new file mode 100644 index 00000000..8d62b45d --- /dev/null +++ b/learn2learn/optim/transforms/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 + +""" +Optimization transforms are special modules that take gradients as inputs +and output model updates. +Transforms are usually parameterized, and those parameters can be learned by +gradient descent, allow you to learn optimization functions from data. +""" + +from .module_transform import ModuleTransform, ReshapedTransform +from .kronecker_transform import KroneckerTransform +from .transform_dictionary import TransformDictionary +from .metacurvature_transform import MetaCurvatureTransform diff --git a/learn2learn/optim/transforms/kronecker_transform.py b/learn2learn/optim/transforms/kronecker_transform.py new file mode 100644 index 00000000..8e6cb845 --- /dev/null +++ b/learn2learn/optim/transforms/kronecker_transform.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +import learn2learn as l2l + + +def get_kronecker_dims(param): + shape = param.shape + if len(shape) == 2: # FC + n, m = shape + elif len(shape) == 1: # Bias + n, m = shape[0], 1 + elif len(shape) == 4: # CNN + n = shape[1] + m = shape[2] * shape[3] + else: + raise NotImplementedError('Layer not supported. Please open an issue.') + return n, m + + +class KroneckerTransform(object): + + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/module_transform.py) + + **Description** + + The KroneckerTransform creates a an optimization transform based on nn.Module's that admit a Kronecker factorization. + (see `l2l.nn.Kronecker*`) + + Akin to the ModuleTransform, this class of transform instanciates a module from its class, based on a given parameter. + But, instead of reshaping the gradients to shape `(1, param.numel())`, this class assumes a Kronecker factorization + of the weights for memory and computational efficiency. + + The specific dimension of the Kronecker factorization depends on the the parameter's shape. + For a weight of shape (n, m), a KroneckerLinear transform consists of two weights with shapes (n, n) and (m, m) rather + than a single weight of shape (nm, nm). + Refer to Arnold et al., 2019 for more details. + + **Arguments** + + * **kronecker_cls** (callable) - A callable that instantiates the Kronecker module used to transform gradients. + + **References** + + 1. Arnold et al. 
2019. "When MAML can adapt fast and how to assist when it cannot". + + **Example** + ~~~python + classifier = torch.nn.Linear(784, 10, bias=False) + kronecker_transform = KroneckerTransform(l2l.nn.KroneckerLinear) + kronecker_update = kronecker_transform(classifier.weight) + loss(classifier(X), y).backward() + update = kronecker_update(classifier.weight.grad) + classifier.weight.data.add_(-lr, update) # Not a differentiable update. See l2l.optim.DifferentiableSGD. + ~~~ + """ + + def __init__(self, kronecker_cls, bias=False, psd=True): + self.kronecker_cls = kronecker_cls + self.bias = bias + self.psd = psd + + def __call__(self, param): + """docstring for forward""" + n, m = get_kronecker_dims(param) + transform = self.kronecker_cls( + n=n, + m=m, + bias=self.bias, + psd=self.psd, + ) + return l2l.optim.transforms.ReshapedTransform( + transform=transform, + shape=(-1, n, m) + ) diff --git a/learn2learn/optim/transforms/metacurvature_transform.py b/learn2learn/optim/transforms/metacurvature_transform.py new file mode 100644 index 00000000..cf02d0c3 --- /dev/null +++ b/learn2learn/optim/transforms/metacurvature_transform.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 + +import torch +import numpy as np + + +class MetaCurvatureTransform(torch.nn.Module): + + """ + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/module_transform.py) + + **Description** + + Implements the Meta-Curvature transform of Park and Oliva, 2019. + + Unlike `ModuleTranform` and `KroneckerTransform`, this class does not wrap other Modules but is directly + called on a weight to instantiate the transform. + + **Arguments** + + * **param** (Tensor) - The weight whose gradients will be transformed. + * **lr** (float, *optional*, default=1.0) - Scaling factor of the udpate. (non-learnable) + + **References** + + 1. Park & Oliva. 2019. Meta-curvature. + + **Example** + ~~~python + classifier = torch.nn.Linear(784, 10, bias=False) + metacurvature_update = MetaCurvatureTransform(classifier.weight) + loss(classifier(X), y).backward() + update = metacurvature_update(classifier.weight.grad) + classifier.weight.data.add_(-lr, update) # Not a differentiable update. See l2l.optim.DifferentiableSGD. + ~~~ + """ + + def __init__(self, param, lr=1.0): + super(MetaCurvatureTransform, self).__init__() + self.lr = lr + shape = param.shape + if len(shape) == 1: # bias + self.dim = 1 + self.mc = torch.nn.Parameter(torch.ones_like(param)) + elif len(shape) == 2: # FC + self.dim = 2 + self.mc_in = torch.nn.Parameter(torch.eye(shape[0])) + self.mc_out = torch.nn.Parameter(torch.eye(shape[1])) + elif len(shape) == 4: # CNN + self.dim = 4 + self.n_in = shape[0] + self.n_out = shape[1] + self.n_f = int(np.prod(shape) / (self.n_in * self.n_out)) + self.mc_in = torch.nn.Parameter(torch.eye(self.n_in)) + self.mc_out = torch.nn.Parameter(torch.eye(self.n_out)) + self.mc_f = torch.nn.Parameter(torch.eye(self.n_f)) + else: + raise NotImplementedError('Parameter with shape', + shape, + 'is not supported by MetaCurvature.') + + def forward(self, grad): + if self.dim == 1: + update = self.mc * grad + elif self.dim == 2: + update = self.mc_in @ grad @ self.mc_out + else: + # Following the ref. 
implementation, we use TensorFlow's shapes
+            # TODO: Rewrite for PyTorch's conv and avoid contiguous()/permute()
+            update = grad.permute(2, 3, 0, 1).contiguous()
+            shape = update.shape
+            update = update.view(-1, self.n_out) @ self.mc_out
+            update = self.mc_f @ update.view(self.n_f, -1)
+            update = update.view(self.n_f, self.n_in, self.n_out)
+            update = update.permute(1, 0, 2).contiguous().view(self.n_in, -1)
+            update = self.mc_in @ update
+            update = update.view(
+                self.n_in,
+                self.n_f,
+                self.n_out).permute(1, 0, 2).contiguous().view(shape)
+            update = update.permute(2, 3, 0, 1).contiguous()
+        return self.lr * update
diff --git a/learn2learn/optim/transforms/module_transform.py b/learn2learn/optim/transforms/module_transform.py
new file mode 100644
index 00000000..c5454f32
--- /dev/null
+++ b/learn2learn/optim/transforms/module_transform.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+import torch
+
+
+class ReshapedTransform(torch.nn.Module):
+    """
+    Helper class that reshapes gradients before feeding them to a Module, and
+    reshapes the update returned by the Module back to the original shape.
+    """
+
+    def __init__(self, transform, shape):
+        super(ReshapedTransform, self).__init__()
+        self.transform = transform
+        self.in_shape = shape
+
+    def forward(self, grad):
+        """Reshapes grad, applies the transform, and restores the original shape."""
+        out_shape = grad.shape
+        update = grad.view(self.in_shape)
+        update = self.transform(update)
+        update = update.view(out_shape)
+        return update
+
+
+class ModuleTransform(object):
+
+    """
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/transforms/module_transform.py)
+
+    **Description**
+
+    The ModuleTransform creates an optimization transform based on any nn.Module.
+
+    ModuleTransform automatically instantiates a module from its class, based on a given parameter.
+    The input and output shapes of the module are set to `(1, param.numel())`.
+
+    When optimizing large layers, this type of transform can quickly run out of memory.
+    See `KroneckerTransform` for a scalable alternative.
+
+    **Arguments**
+
+    * **module_cls** (callable) - A callable that instantiates the module used to transform gradients.
+
+    **Example**
+    ~~~python
+    classifier = torch.nn.Linear(784, 10, bias=False)
+    linear_transform = ModuleTransform(torch.nn.Linear)
+    linear_update = linear_transform(classifier.weight)  # maps gradients to updates, both of shape (1, 7840)
+    loss(classifier(X), y).backward()
+    update = linear_update(classifier.weight.grad)
+    classifier.weight.data.add_(-lr, update)  # Not a differentiable update. See l2l.optim.DifferentiableSGD.
+    ~~~
+    """
+
+    def __init__(self, module_cls):
+        self.module_cls = module_cls
+
+    def __call__(self, parameter):
+        numel = parameter.numel()
+        flat_shape = (1, numel)
+        transform = self.module_cls(numel, numel)
+        return ReshapedTransform(
+            transform=transform,
+            shape=flat_shape,
+        )
diff --git a/learn2learn/optim/transforms/transform_dictionary.py b/learn2learn/optim/transforms/transform_dictionary.py
new file mode 100644
index 00000000..5e766122
--- /dev/null
+++ b/learn2learn/optim/transforms/transform_dictionary.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+
+import torch
+
+
+class TransformDictionary(object):
+    """Maps Modules or Parameters to their respective gradient transform."""
+
+    def __init__(self, dictionary):
+        self.param_to_transform = {}
+        for key, transform in dictionary.items():
+            if isinstance(key, torch.nn.Module):
+                for p in key.parameters():
+                    self.param_to_transform[p] = transform
+            elif isinstance(key, torch.nn.Parameter):
+                self.param_to_transform[key] = transform
+            else:
+                raise ValueError(
+                    'TransformDictionary only accepts Modules'
+                    + ' or Parameters as dictionary keys.')
+
+    def __call__(self, param, *args, **kwargs):
+        if param in self.param_to_transform:
+            return self.param_to_transform[param](param, *args, **kwargs)
+        return None
diff --git a/learn2learn/optim/update_rules/__init__.py b/learn2learn/optim/update_rules/__init__.py
new file mode 100644
index 00000000..4a1a8d01
--- /dev/null
+++ b/learn2learn/optim/update_rules/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+
+from .differentiable_sgd import DifferentiableSGD
diff --git a/learn2learn/optim/update_rules/differentiable_sgd.py b/learn2learn/optim/update_rules/differentiable_sgd.py
new file mode 100644
index 00000000..844917fb
--- /dev/null
+++ b/learn2learn/optim/update_rules/differentiable_sgd.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+import torch
+from learn2learn.utils import update_module
+
+
+class DifferentiableSGD(torch.nn.Module):
+    r"""
+    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/optim/update_rules/differentiable_sgd.py)
+
+    **Description**
+
+    A callable object that applies a list of updates to the parameters of a torch.nn.Module in a differentiable manner.
+
+    For each parameter \(p\) and corresponding gradient \(g\), calling an instance of this class results in updating parameters:
+
+    \[
+    p \gets p - \alpha g,
+    \]
+
+    where \(\alpha\) is the learning rate.
+
+    Note: The module is updated in-place.
+
+    **Arguments**
+
+    * **lr** (float) - The learning rate used to update the model.
+
+    **Example**
+    ~~~python
+    sgd = DifferentiableSGD(0.1)
+    gradients = torch.autograd.grad(
+        loss,
+        model.parameters(),
+        create_graph=True)
+    sgd(model, gradients)  # model is updated in-place
+    ~~~
+    """
+
+    def __init__(self, lr):
+        super(DifferentiableSGD, self).__init__()
+        self.lr = lr
+
+    def forward(self, module, gradients=None):
+        """
+        **Arguments**
+
+        * **module** (Module) - The module to update.
+        * **gradients** (list, *optional*, default=None) - A list of gradients for each parameter
+            of the module. If None, will use the gradients in .grad attributes.
+ + """ + if gradients is None: + gradients = [p.grad for p in module.parameters()] + updates = [None if g is None else g.mul(-self.lr) + for g in gradients] + update_module(module, updates) diff --git a/learn2learn/utils.py b/learn2learn/utils.py index c9dc7795..b5ab8dbf 100644 --- a/learn2learn/utils.py +++ b/learn2learn/utils.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import copy - import torch @@ -178,12 +177,12 @@ def clone_distribution(dist): for param_key in clone.__dict__: item = clone.__dict__[param_key] - if isinstance(item, th.Tensor): + if isinstance(item, torch.Tensor): if item.requires_grad: clone.__dict__[param_key] = dist.__dict__[param_key].clone() - elif isinstance(item, th.nn.Module): + elif isinstance(item, torch.nn.Module): clone.__dict__[param_key] = clone_module(dist.__dict__[param_key]) - elif isinstance(item, th.Distribution): + elif isinstance(item, torch.Distribution): clone.__dict__[param_key] = clone_distribution(dist.__dict__[param_key]) return clone @@ -193,11 +192,77 @@ def detach_distribution(dist): # TODO: This function was never tested. for param_key in dist.__dict__: item = dist.__dict__[param_key] - if isinstance(item, th.Tensor): + if isinstance(item, torch.Tensor): if item.requires_grad: dist.__dict__[param_key] = dist.__dict__[param_key].detach() - elif isinstance(item, th.nn.Module): + elif isinstance(item, torch.nn.Module): dist.__dict__[param_key] = detach_module(dist.__dict__[param_key]) - elif isinstance(item, th.Distribution): + elif isinstance(item, torch.Distribution): dist.__dict__[param_key] = detach_distribution(dist.__dict__[param_key]) return dist + + +def update_module(module, updates=None): + r""" + [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py) + + **Description** + + Updates the parameters of a module in-place, in a way that preserves differentiability. + + The parameters of the module are swapped with their update values, according to: + \[ + p \gets p + u, + \] + where \(p\) is the parameter, and \(u\) is its corresponding update. + + + **Arguments** + + * **module** (Module) - The module to update. + * **updates** (list, *optional*, default=None) - A list of gradients for each parameter + of the model. If None, will use the tensors in .update attributes. + + **Example** + ~~~python + error = loss(model(X), y) + grads = torch.autograd.grad( + error, + model.parameters(), + create_graph=True, + ) + updates = [-lr * g for g in grads] + l2l.update_module(model, updates=updates) + ~~~ + """ + if updates is not None: + params = list(module.parameters()) + if not len(updates) == len(list(params)): + msg = 'WARNING:update_module(): Parameters and updates have different length. 
(' + msg += str(len(params)) + ' vs ' + str(len(updates)) + ')' + print(msg) + for p, g in zip(params, updates): + p.update = g + + # Update the params + for param_key in module._parameters: + p = module._parameters[param_key] + if p is not None and hasattr(p, 'update') and p.update is not None: + module._parameters[param_key] = p + p.update + + # Second, handle the buffers if necessary + for buffer_key in module._buffers: + buff = module._buffers[buffer_key] + if buff is not None and hasattr(buff, 'update') and buff.update is not None: + module._buffers[buffer_key] = buff + buff.update + + # Then, recurse for each submodule + for module_key in module._modules: + module._modules[module_key] = update_module(module._modules[module_key], + updates=None) + + # Finally, rebuild the flattened parameters for RNNs + # See this issue for more details: + # https://github.com/learnables/learn2learn/issues/139 + module._apply(lambda x: x) + return module diff --git a/learn2learn/vision/__init__.py b/learn2learn/vision/__init__.py index d6d3c699..f4e47ee7 100644 --- a/learn2learn/vision/__init__.py +++ b/learn2learn/vision/__init__.py @@ -1,5 +1,9 @@ #!/usr/bin/env python3 +r""" +Datasets, models, and other utilities related to computer vision. +""" + from . import datasets from . import models from . import transforms diff --git a/tests/integration/maml_miniimagenet_test_notravis.py b/tests/integration/maml_miniimagenet_test_notravis.py index 92ddcfd1..bd0cab74 100644 --- a/tests/integration/maml_miniimagenet_test_notravis.py +++ b/tests/integration/maml_miniimagenet_test_notravis.py @@ -62,9 +62,9 @@ def main( device = torch.device('cuda') # Create Datasets - train_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='train') - valid_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='validation') - test_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='test') + train_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='train') + valid_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='validation') + test_dataset = l2l.vision.datasets.MiniImagenet(root='~/data', mode='test') train_dataset = l2l.data.MetaDataset(train_dataset) valid_dataset = l2l.data.MetaDataset(valid_dataset) test_dataset = l2l.data.MetaDataset(test_dataset) diff --git a/tests/integration/protonets_miniimagenet_test_notravis.py b/tests/integration/protonets_miniimagenet_test_notravis.py index 5eb78fa7..6c975b2f 100644 --- a/tests/integration/protonets_miniimagenet_test_notravis.py +++ b/tests/integration/protonets_miniimagenet_test_notravis.py @@ -115,7 +115,7 @@ def main(num_iterations=250): model = Convnet() model.to(device) - path_data = './data' + path_data = '~/data' train_dataset = l2l.vision.datasets.MiniImagenet( root=path_data, mode='train') valid_dataset = l2l.vision.datasets.MiniImagenet( diff --git a/tests/unit/algorithms/gbml_test.py b/tests/unit/algorithms/gbml_test.py new file mode 100644 index 00000000..a73dc928 --- /dev/null +++ b/tests/unit/algorithms/gbml_test.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 + +import unittest +import torch +import learn2learn as l2l + +NUM_INPUTS = 7 +INPUT_SIZE = 10 +HIDDEN_SIZE = 20 +INNER_LR = 0.01 +EPSILON = 1e-8 + + +class LR(torch.nn.Module): + + def __init__(self, input_size, output_size): + super(LR, self).__init__() + self.lr = torch.ones(input_size) + self.lr = torch.nn.Parameter(self.lr) + + def forward(self, grad): + return self.lr * grad + + +def close(x, y): + return (x - y).norm(p=2) <= EPSILON + + +class 
TestGBMLgorithm(unittest.TestCase): + + def setUp(self): + self.model = torch.nn.Sequential( + torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE), + torch.nn.ReLU(), + torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE), + torch.nn.Sigmoid(), + torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE), + torch.nn.Softmax()) + + self.model.register_buffer('dummy_buf', torch.zeros(1, 2, 3, 4)) + + def tearDown(self): + pass + + def test_clone_module(self): + for first_order in [False, True]: + transform = l2l.optim.transforms.ModuleTransform(LR) + gbml = l2l.algorithms.GBML(self.model, + transform=transform, + first_order=first_order, + lr=INNER_LR) + X = torch.randn(NUM_INPUTS, INPUT_SIZE) + ref = self.model(X) + for clone in [gbml.clone(), gbml.clone()]: + out = clone(X) + self.assertTrue(close(ref, out)) + + def test_graph_connection(self): + for adapt_transform in [False, True]: + transform = l2l.optim.transforms.ModuleTransform(LR) + gbml = l2l.algorithms.GBML(self.model, + transform=transform, + adapt_transform=adapt_transform, + lr=INNER_LR) + X = torch.randn(NUM_INPUTS, INPUT_SIZE) + ref = gbml(X) + clone = gbml.clone() + out = clone(X) + out.norm(p=2).backward() + for p in self.model.parameters(): + self.assertTrue(hasattr(p, 'grad')) + self.assertTrue(p.grad.norm(p=2).item() > 0.0) + + def test_adaptation(self): + for adapt_transform in [False, True]: + transform = l2l.optim.transforms.ModuleTransform(LR) + gbml = l2l.algorithms.GBML(self.model, + transform=transform, + adapt_transform=adapt_transform, + lr=INNER_LR) + X = torch.randn(NUM_INPUTS, INPUT_SIZE) + clone = gbml.clone() + loss = clone(X).norm(p=2) + clone.adapt(loss) + new_loss = clone(X).norm(p=2) + self.assertTrue(loss >= new_loss) + new_loss.backward() + for p in self.model.parameters(): + self.assertTrue(hasattr(p, 'grad')) + self.assertTrue(p.grad.norm(p=2).item() > 0.0) + + def test_allow_unused(self): + transform = l2l.optim.transforms.ModuleTransform(LR) + gbml = l2l.algorithms.GBML(self.model, + transform=transform, + lr=INNER_LR, + allow_unused=True) + clone = gbml.clone() + loss = 0.0 + for i, p in enumerate(clone.parameters()): + if i % 2 == 0: + loss += p.norm(p=2) + clone.adapt(loss) + loss = 0.0 + for i, p in enumerate(clone.parameters()): + if i % 2 == 0: + loss += p.norm(p=2) + loss.backward() + for p in gbml.parameters(): + self.assertTrue(hasattr(p, 'grad')) + + def test_allow_nograd(self): + self.model[2].weight.requires_grad = False + transform = l2l.optim.transforms.ModuleTransform(LR) + gbml = l2l.algorithms.GBML(self.model, + transform=transform, + lr=INNER_LR, + allow_unused=False, + allow_nograd=False) + clone = gbml.clone() + + loss = sum([p.norm(p=2) for p in clone.parameters()]) + try: + # Check that without allow_nograd, adaptation fails + clone.adapt(loss) + self.assertTrue(False, 'adaptation successful despite requires_grad=False') # Check that execution never gets here + except: + # Check that with allow_nograd, adaptation succeeds + clone.adapt(loss, allow_nograd=True) + loss = sum([p.norm(p=2) for p in clone.parameters()]) + loss.backward() + self.assertTrue(self.model[2].weight.grad is None) + for p in self.model.parameters(): + if p.requires_grad: + self.assertTrue(p.grad is not None) + + transform = l2l.optim.transforms.ModuleTransform(LR) + gbml = l2l.algorithms.GBML(self.model, + transform=transform, + lr=INNER_LR, + allow_nograd=True) + clone = gbml.clone() + loss = sum([p.norm(p=2) for p in clone.parameters()]) + # Check that without allow_nograd, adaptation succeeds thanks to init. 
+ orig_weight = self.model[2].weight.clone().detach() + clone.adapt(loss) + self.assertTrue(close(orig_weight, self.model[2].weight)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit/data/util_datasets.py b/tests/unit/data/util_datasets.py index 1a4cd275..44209520 100644 --- a/tests/unit/data/util_datasets.py +++ b/tests/unit/data/util_datasets.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + import string import numpy as np diff --git a/tests/unit/vision/benchmarks_test.py b/tests/unit/vision/benchmarks_test.py index 184482d3..a6ef6853 100644 --- a/tests/unit/vision/benchmarks_test.py +++ b/tests/unit/vision/benchmarks_test.py @@ -21,7 +21,7 @@ def test_tasksets(self): for name in names: if name in TOO_BIG_TO_TEST: continue - tasksets = l2l.vision.benchmarks.get_tasksets(name, root='./data') + tasksets = l2l.vision.benchmarks.get_tasksets(name, root='~/data') self.assertTrue(hasattr(tasksets, 'train')) batch = tasksets.train.sample() self.assertTrue(batch is not None)