Skip to content

Commit

Permalink
refactor apis (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya authored May 18, 2022
1 parent 546bad8 commit 74eced1
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 14 deletions.
68 changes: 59 additions & 9 deletions examples/nbs/02-TextClassification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
"source": [
"from flash.core.data.utils import download_data\n",
"from flash.text import TextClassificationData\n",
"from gradsflow import AutoTextClassifier"
"from gradsflow import AutoTextClassifier\n",
"\n",
"import ray\n",
"\n",
"ray.init(address=\"auto\")"
]
},
{
Expand All @@ -63,7 +67,7 @@
" \"sentiment\",\n",
" train_file=\"data/imdb/train.csv\",\n",
" val_file=\"data/imdb/valid.csv\",\n",
" batch_size=4,\n",
" batch_size=16,\n",
")"
]
},
Expand All @@ -74,31 +78,77 @@
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"tags": []
},
"outputs": [],
"source": [
"suggested_conf = dict(\n",
" optimizers=[\"adam\"],\n",
" lr=(5e-4, 1e-3),\n",
" optimizer=[\"adam\", \"adamw\"],\n",
" lr=(4e-3, 1e-2),\n",
")\n",
"\n",
"model = AutoTextClassifier(\n",
" datamodule,\n",
" suggested_backbones=[\"sgugger/tiny-distilbert-classification\"],\n",
" suggested_backbones=[\"prajjwal1/bert-medium\"],\n",
" suggested_conf=suggested_conf,\n",
" max_epochs=2,\n",
" max_epochs=1,\n",
" optimization_metric=\"val_accuracy\",\n",
" n_trials=1,\n",
")\n",
"\n",
"print(\"AutoTextClassifier initialised!\")\n",
"model.hp_tune()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a9c19a3",
"metadata": {},
"outputs": [],
"source": [
"model.analysis.dataframe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d0b96bd0-3949-4dd5-b6d9-4fa6ac213e30",
"metadata": {},
"outputs": [],
"source": [
"from flash import Trainer\n",
"trainer = Trainer()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a95bc01-4291-49a8-bea7-c76ba64bf42a",
"metadata": {},
"outputs": [],
"source": [
"trainer.validate(model.model, datamodule=datamodule)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5b20a6f-9efd-493e-94ea-ce00e8879451",
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()"
]
}
],
"metadata": {
"interpreter": {
"hash": "e6697cd4c0f4f58297a92a2dfda85db933b7e27cf6bc19e3dafb7e93fff75254"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -112,7 +162,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
"version": "3.9.12"
}
},
"nbformat": 4,
Expand Down
7 changes: 4 additions & 3 deletions examples/src/tasks/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
)

suggested_conf = dict(
optimizers=["adam"],
optimizers=["adam", "adamw"],
lr=(5e-4, 1e-3),
)

model = AutoTextClassifier(
datamodule,
suggested_backbones=["sgugger/tiny-distilbert-classification"],
suggested_backbones=["prajjwal1/bert-medium"],
suggested_conf=suggested_conf,
max_epochs=2,
max_epochs=1,
optimization_metric="val_accuracy",
n_trials=4,
)

print("AutoTextClassifier initialised!")
Expand Down
5 changes: 3 additions & 2 deletions gradsflow/autotasks/autoclassification/text/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AutoTextClassifier(AutoClassifier):
"sentiment",
train_file="data/imdb/train.csv",
val_file="data/imdb/valid.csv",
batch_size=4,
)
model = AutoTextClassifier(datamodule,
Expand Down Expand Up @@ -66,8 +67,8 @@ class AutoTextClassifier(AutoClassifier):
"sgugger/tiny-distilbert-classification",
]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, *args, max_steps=-1, **kwargs):
super().__init__(*args, max_steps=max_steps, **kwargs)
meta = self.auto_dataset.meta
self.num_classes = meta.get("num_labels") or meta.get("num_classes")
logger.debug(f"num_classes = {self.num_classes}")
Expand Down

0 comments on commit 74eced1

Please sign in to comment.