diff --git a/examples/nbs/02-TextClassification.ipynb b/examples/nbs/02-TextClassification.ipynb index 144d211..01aa307 100644 --- a/examples/nbs/02-TextClassification.ipynb +++ b/examples/nbs/02-TextClassification.ipynb @@ -42,7 +42,11 @@ "source": [ "from flash.core.data.utils import download_data\n", "from flash.text import TextClassificationData\n", - "from gradsflow import AutoTextClassifier" + "from gradsflow import AutoTextClassifier\n", + "\n", + "import ray\n", + "\n", + "ray.init(address=\"auto\")" ] }, { @@ -63,7 +67,7 @@ " \"sentiment\",\n", " train_file=\"data/imdb/train.csv\",\n", " val_file=\"data/imdb/valid.csv\",\n", - " batch_size=4,\n", + " batch_size=16,\n", ")" ] }, @@ -74,31 +78,77 @@ "metadata": { "pycharm": { "name": "#%%\n" - } + }, + "tags": [] }, "outputs": [], "source": [ "suggested_conf = dict(\n", - " optimizers=[\"adam\"],\n", - " lr=(5e-4, 1e-3),\n", + " optimizer=[\"adam\", \"adamw\"],\n", + " lr=(4e-3, 1e-2),\n", ")\n", "\n", "model = AutoTextClassifier(\n", " datamodule,\n", - " suggested_backbones=[\"sgugger/tiny-distilbert-classification\"],\n", + " suggested_backbones=[\"prajjwal1/bert-medium\"],\n", " suggested_conf=suggested_conf,\n", - " max_epochs=2,\n", + " max_epochs=1,\n", " optimization_metric=\"val_accuracy\",\n", + " n_trials=1,\n", ")\n", "\n", "print(\"AutoTextClassifier initialised!\")\n", "model.hp_tune()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a9c19a3", + "metadata": {}, + "outputs": [], + "source": [ + "model.analysis.dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0b96bd0-3949-4dd5-b6d9-4fa6ac213e30", + "metadata": {}, + "outputs": [], + "source": [ + "from flash import Trainer\n", + "trainer = Trainer()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a95bc01-4291-49a8-bea7-c76ba64bf42a", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.validate(model.model, datamodule=datamodule)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5b20a6f-9efd-493e-94ea-ce00e8879451", + "metadata": {}, + "outputs": [], + "source": [ + "ray.shutdown()" + ] } ], "metadata": { + "interpreter": { + "hash": "e6697cd4c0f4f58297a92a2dfda85db933b7e27cf6bc19e3dafb7e93fff75254" + }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -112,7 +162,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.9.12" } }, "nbformat": 4, diff --git a/examples/src/tasks/text_classification.py b/examples/src/tasks/text_classification.py index e0029b9..5b820b0 100644 --- a/examples/src/tasks/text_classification.py +++ b/examples/src/tasks/text_classification.py @@ -14,16 +14,17 @@ ) suggested_conf = dict( - optimizers=["adam"], + optimizers=["adam", "adamw"], lr=(5e-4, 1e-3), ) model = AutoTextClassifier( datamodule, - suggested_backbones=["sgugger/tiny-distilbert-classification"], + suggested_backbones=["prajjwal1/bert-medium"], suggested_conf=suggested_conf, - max_epochs=2, + max_epochs=1, optimization_metric="val_accuracy", + n_trials=4, ) print("AutoTextClassifier initialised!") diff --git a/gradsflow/autotasks/autoclassification/text/text.py b/gradsflow/autotasks/autoclassification/text/text.py index bc2805c..2b4f688 100644 --- a/gradsflow/autotasks/autoclassification/text/text.py +++ b/gradsflow/autotasks/autoclassification/text/text.py @@ -37,6 +37,7 @@ class AutoTextClassifier(AutoClassifier): "sentiment", train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", + batch_size=4, ) model = AutoTextClassifier(datamodule, @@ -66,8 +67,8 @@ class AutoTextClassifier(AutoClassifier): "sgugger/tiny-distilbert-classification", ] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *args, max_steps=-1, **kwargs): + super().__init__(*args, max_steps=max_steps, **kwargs) meta = self.auto_dataset.meta self.num_classes = meta.get("num_labels") or meta.get("num_classes") logger.debug(f"num_classes = {self.num_classes}")