fix examples and Flash trainer (#183)
* fix examples

* update

* fix

* update
aniketmaurya authored May 18, 2022
1 parent 60ca633 commit 546bad8
Showing 9 changed files with 102 additions and 43 deletions.
74 changes: 60 additions & 14 deletions examples/nbs/01-ImageClassification.ipynb
@@ -2,7 +2,11 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"<!--<badge>--><a href=\"https://colab.research.google.com/github/gradsflow/gradsflow/blob/main/examples/nbs/01-ImageClassification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a><!--</badge>-->\n",
"\n",
@@ -15,6 +19,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"tags": []
},
"outputs": [],
@@ -30,7 +37,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"import torchvision\n",
@@ -43,15 +54,23 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"Let's use `CalTech101` dataset provided by `torchvision`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -75,7 +94,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"train_data, val_data = random_split_dataset(data, 0.01)\n",
@@ -86,15 +109,23 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"If you want to run Gradsflow on a remote server then first setup [ray cluster](https://docs.ray.io/en/master/cluster/index.html) and initialize ray with the remote address."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# ray.init(address=\"REMOTE_IP_ADDR\")\n",
@@ -103,15 +134,23 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"To train an image classifier create an object of `AutoImageClassifier` and provide number of trials and timeout."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
@@ -126,18 +165,21 @@
" train_dataloader=train_dl,\n",
" val_dataloader=val_dl,\n",
" num_classes=num_classes,\n",
" max_epochs=5,\n",
" max_epochs=2,\n",
" optimization_metric=\"train_loss\",\n",
" max_steps=1,\n",
" n_trials=1,\n",
" n_trials=4,\n",
")\n",
"print(\"AutoImageClassifier initialised!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
@@ -758,7 +800,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# ray.shutdown()"
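Taken together, the notebook cells above boil down to this flow: build train/val dataloaders, construct an AutoImageClassifier with the search budget (n_trials, max_epochs, max_steps), and start tuning. A condensed, self-contained sketch under two assumptions: that AutoImageClassifier is exported at the gradsflow top level (mirroring AutoTextClassifier in the example script further down) and that hp_tune() starts the search, as it does in examples/src/tasks/text_classification.py; FakeData simply stands in for the CalTech101 data used in the notebook.

import torchvision
from torch.utils.data import DataLoader

from gradsflow import AutoImageClassifier  # assumed top-level export

# Tiny stand-in dataset so the sketch runs on its own; the notebook uses CalTech101.
num_classes = 10
transform = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.FakeData(size=64, num_classes=num_classes, transform=transform)
val_data = torchvision.datasets.FakeData(size=16, num_classes=num_classes, transform=transform)

train_dl = DataLoader(train_data, batch_size=8)
val_dl = DataLoader(val_data, batch_size=8)

model = AutoImageClassifier(
    train_dataloader=train_dl,
    val_dataloader=val_dl,
    num_classes=num_classes,
    max_epochs=2,
    max_steps=1,
    optimization_metric="train_loss",
    n_trials=4,
)
model.hp_tune()  # runs the hyperparameter search over the 4 trials configured above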
9 changes: 6 additions & 3 deletions examples/nbs/02-TextClassification.ipynb
@@ -71,7 +71,11 @@
"cell_type": "code",
"execution_count": null,
"id": "9f5d0474",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"suggested_conf = dict(\n",
@@ -83,9 +87,8 @@
" datamodule,\n",
" suggested_backbones=[\"sgugger/tiny-distilbert-classification\"],\n",
" suggested_conf=suggested_conf,\n",
" max_epochs=1,\n",
" max_epochs=2,\n",
" optimization_metric=\"val_accuracy\",\n",
" timeout=5,\n",
")\n",
"\n",
"print(\"AutoTextClassifier initialised!\")\n",
2 changes: 1 addition & 1 deletion examples/nbs/03-TextSummarization.ipynb
@@ -68,7 +68,7 @@
" suggested_conf=suggested_conf,\n",
" max_epochs=1,\n",
" optimization_metric=\"train_loss\",\n",
" timeout=5,\n",
" timeout=600,\n",
")\n",
"\n",
"print(\"AutoSummarization initialised!\")\n",
14 changes: 3 additions & 11 deletions examples/nbs/04-RayDataset.ipynb
@@ -59,7 +59,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1c550a6f",
"id": "b94e30dd",
"metadata": {
"pycharm": {
"name": "#%%\n"
@@ -77,21 +77,13 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ef8c7353",
"id": "3822e6bb",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-09-26 01:55:36,868\tINFO services.py:1263 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
]
}
],
"outputs": [],
"source": [
"transforms = get_augmentations()\n",
"\n",
8 changes: 6 additions & 2 deletions examples/src/tasks/text_classification.py
@@ -1,8 +1,11 @@
import ray
from flash.core.data.utils import download_data
from flash.text import TextClassificationData

from gradsflow import AutoTextClassifier

ray.init(address="auto")

download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

print("Creating datamodule...")
@@ -14,14 +17,15 @@
optimizers=["adam"],
lr=(5e-4, 1e-3),
)

model = AutoTextClassifier(
datamodule,
suggested_backbones=["sgugger/tiny-distilbert-classification"],
suggested_conf=suggested_conf,
max_epochs=1,
max_epochs=2,
optimization_metric="val_accuracy",
timeout=5,
)

print("AutoTextClassifier initialised!")
model.hp_tune()
ray.shutdown()
2 changes: 2 additions & 0 deletions gradsflow/autotasks/autoclassification/text/text.py
@@ -84,6 +84,7 @@ def build_model(self, config: dict) -> torch.nn.Module:
learning_rate [float]: Learning rate for the model.
"""
from flash.text.classification import TextClassifier
from torchmetrics import Accuracy

backbone = config["backbone"]
optimizer = config["optimizer"]
@@ -94,4 +95,5 @@ def build_model(self, config: dict) -> torch.nn.Module:
backbone=backbone,
optimizer=self._OPTIMIZER_INDEX[optimizer],
learning_rate=learning_rate,
metrics=Accuracy(num_classes=self.num_classes),
)
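The functional change in this file is the metrics=Accuracy(num_classes=...) argument, which attaches a torchmetrics accuracy metric to the Flash TextClassifier, presumably so the val_accuracy value used as the optimization metric in the examples is actually reported during training. For reference, a minimal sketch of the metric on its own (the sample tensors are made up for illustration; torchmetrics releases newer than the one targeted here instead require Accuracy(task="multiclass", num_classes=3)):

import torch
from torchmetrics import Accuracy

accuracy = Accuracy(num_classes=3)   # multiclass accuracy, pre-0.11 torchmetrics API
preds = torch.tensor([0, 2, 1, 2])   # predicted class indices
target = torch.tensor([0, 1, 1, 2])  # ground-truth class indices
print(accuracy(preds, target))       # tensor(0.7500), 3 of 4 predictions correct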
20 changes: 15 additions & 5 deletions gradsflow/autotasks/engine/backend.py
@@ -14,6 +14,7 @@

import logging
import math
import typing
from enum import Enum
from typing import Callable, Dict, Optional

@@ -24,10 +25,18 @@
from gradsflow.utility.common import module_to_cls_index
from gradsflow.utility.imports import is_installed

pl = None
if is_installed("pytorch_lightning"):
if typing.TYPE_CHECKING:
import pytorch_lightning as pl


if is_installed("pytorch_lightning"):
from flash import Task
from flash import Trainer as FlashTrainer
from pytorch_lightning import Trainer as PLTrainer
else:
FlashTrainer = None
PLTrainer = None

logger = logging.getLogger("core.backend")


@@ -83,10 +92,12 @@ def _lightning_objective(
val_check_interval = max(self.max_steps - 1, 1.0)

datamodule = self.autodataset.datamodule
model = self.model_builder(config)

trainer_cls = FlashTrainer if isinstance(model, Task) else PLTrainer

trainer = pl.Trainer(
trainer: "pl.Trainer" = trainer_cls(
logger=True,
checkpoint_callback=False,
gpus=math.ceil(gpu),
max_epochs=self.max_epochs,
max_steps=self.max_steps,
Expand All @@ -95,7 +106,6 @@ def _lightning_objective(
**trainer_config,
)

model = self.model_builder(config)
hparams = dict(model=model.hparams)
trainer.logger.log_hyperparams(hparams)
trainer.fit(model, datamodule=datamodule)
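The core of the Flash trainer fix is visible in this hunk: the backend now builds the model first and then picks the trainer class from the model type, using flash.Trainer for Flash Task models and the plain pytorch_lightning.Trainer otherwise, while the typing.TYPE_CHECKING import keeps the "pl.Trainer" annotation resolvable without importing pytorch_lightning at runtime. A minimal standalone sketch of the same dispatch pattern (build_trainer and trainer_kwargs are illustrative names, not gradsflow API; the real logic lives in gradsflow/autotasks/engine/backend.py):

from flash import Task
from flash import Trainer as FlashTrainer
from pytorch_lightning import Trainer as PLTrainer


def build_trainer(model, **trainer_kwargs):
    # Flash models subclass flash.Task and need the Flash Trainer;
    # plain LightningModules fall back to the Lightning Trainer.
    trainer_cls = FlashTrainer if isinstance(model, Task) else PLTrainer
    return trainer_cls(logger=True, **trainer_kwargs)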
7 changes: 4 additions & 3 deletions tests/autotasks/test_autotrainer.py
@@ -20,15 +20,16 @@
trainer_config = {"show_progress": False}


@patch("gradsflow.autotasks.engine.backend.pl")
def test_optimization_objective(mock_pl: Mock):
@patch("gradsflow.autotasks.engine.backend.FlashTrainer")
@patch("gradsflow.autotasks.engine.backend.PLTrainer")
def test_optimization_objective(mock_pl_trainer: Mock, mock_fl_trainer: Mock):
dm = MagicMock()
model_builder = MagicMock()

# backend_type is pl
autotrainer = Backend(dm, model_builder, optimization_metric="val_accuracy", backend="pl")
autotrainer.optimization_objective({}, trainer_config)
mock_pl.Trainer.assert_called()
assert mock_pl_trainer.called or mock_fl_trainer.called

# wrong backend_type is passed
with pytest.raises(NotImplementedError):
9 changes: 5 additions & 4 deletions tests/autotasks/test_core_automodel.py
@@ -48,8 +48,9 @@ def test_create_search_space():


@patch.multiple(AutoModel, __abstractmethods__=set())
@patch("gradsflow.autotasks.engine.backend.pl")
def test_objective(mock_pl):
@patch("gradsflow.autotasks.engine.backend.FlashTrainer")
@patch("gradsflow.autotasks.engine.backend.PLTrainer")
def test_objective(mock_pl_trainer, mock_fl_trainer):
optimization_metric = "val_accuracy"
model = AutoModel(
datamodule,
@@ -58,8 +59,8 @@ def test_objective(mock_pl):
)

model.backend.model_builder = MagicMock()
trainer = mock_pl.Trainer = MagicMock()
trainer.callback_metrics = {optimization_metric: torch.as_tensor([1])}

mock_pl_trainer.callback_metrics = mock_fl_trainer.callback_metrics = {optimization_metric: torch.as_tensor([1])}

model.backend.optimization_objective({}, {})

