diff --git a/lightly/utils/dependency.py b/lightly/utils/dependency.py index 1dd2512d0..6279e3584 100644 --- a/lightly/utils/dependency.py +++ b/lightly/utils/dependency.py @@ -46,13 +46,16 @@ def timm_vit_available() -> bool: @functools.lru_cache(maxsize=1) def torchvision_transforms_v2_available() -> bool: - """Checks if torchvision supports the transforms.v2 API. + """Checks if torchvision supports the v2 transforms API with the `tv_tensors` + module. Checking for the availability of the `transforms.v2` is not sufficient + since it is available in torchvision >= 0.15.1, but the `tv_tensors` module is + only available in torchvision >= 0.16.0. Returns: True if transforms.v2 are available, False otherwise """ try: - from torchvision.tv_tensors import Mask + from torchvision import tv_tensors except ImportError: return False return True