diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index a1e0c0ab45e9e..a9dd656394c7c 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -1,6 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment +requests <2.32.0 torchvision >=0.14.0, <0.17.0 gym[classic_control] >=0.17.0, <0.27.0 ipython[all] <8.15.0 diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index abb9e2e85a05c..4cd96115807ab 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -9,16 +9,19 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple -import requests import torch import torch.nn as nn import torch.nn.functional as F +from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn.modules import MultiheadAttention from torch.utils.data import DataLoader, Dataset from lightning.pytorch import LightningModule +_REQUESTS_AVAILABLE = RequirementCache("requests") + + if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAttention, "reset_parameters"): # See https://github.com/pytorch/pytorch/issues/107909 MultiheadAttention.reset_parameters = MultiheadAttention._reset_parameters @@ -125,6 +128,11 @@ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: @staticmethod def download(destination: Path) -> None: + if not _REQUESTS_AVAILABLE: + raise ModuleNotFoundError(str(_REQUESTS_AVAILABLE)) + + import requests + os.makedirs(destination.parent, exist_ok=True) url = "https://raw.githubusercontent.com/pytorch/examples/main/word_language_model/data/wikitext-2/train.txt" if os.path.exists(destination):