
Commit d332ae3

SkafteNicki, Borda, and mergify[bot] authored
CLIP score improvements (#2978)
* add to examples
* functional improvement
* implement in tests
* fix processor for jina
* add to modular
* add changelog
* note
* fix testing
* add missing requirements
* fix truncation issue
* bit of refactor loading processor and model
* fix stupid mistake in naming
* fixes
* try fixing typing
* refactor
* timm
* einops
* drop
* try fixing tests
* fix more tests
* fix remaining tests
* try fixing more tests
* lower test burden
* try fixing

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent d3b33b6 commit d332ae3

9 files changed: +222 -49 lines changed


CHANGELOG.md (+3)

@@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `ARNIQA` metric to image domain ([#2953](https://github.com/PyTorchLightning/metrics/pull/2953))


+- Added support for more models and processors in `CLIPScore` ([#2978](https://github.com/PyTorchLightning/metrics/pull/2978))
+
+
 ### Changed

 -

examples/image/clip_score.py (+1 -3)

@@ -41,9 +41,7 @@

 models = [
     "openai/clip-vit-base-patch16",
-    # "openai/clip-vit-base-patch32",
-    # "openai/clip-vit-large-patch14-336",
-    "openai/clip-vit-large-patch14",
+    # "zer0int/LongCLIP-L-Diffusers",
 ]

 # %%

requirements/multimodal.txt (+2)

@@ -3,3 +3,5 @@

 transformers >=4.42.3, <4.50.0
 piq <=0.8.0
+einops >=0.7.0, <=0.8.1  # CLIP dependency
+timm >=0.9.0, <1.1.0  # CLIP: needed for transformers/models--jinaai--jina-clip-implementation

src/torchmetrics/functional/multimodal/clip_score.py (+88 -27)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, List, Union, cast

 import torch
 from torch import Tensor

@@ -41,6 +41,24 @@ def _download_clip_for_clip_score() -> None:
     _CLIPProcessor = None


+class JinaProcessorWrapper:
+    """Wrapper class to convert tensors to PIL images if needed for Jina CLIP model."""
+
+    def __init__(self, processor: _CLIPProcessor) -> None:
+        self.processor = processor
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        """Wrap the processor's __call__ method to convert tensors to PIL images if needed."""
+        # Check if 'images' is in kwargs and convert tensors to PIL images if needed
+        from torchvision.transforms.functional import to_pil_image
+
+        if "images" in kwargs:
+            kwargs["images"] = [
+                to_pil_image(img.float().cpu()) if isinstance(img, Tensor) else img for img in kwargs["images"]
+            ]
+        return self.processor(*args, **kwargs)
+
+
 def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
     """Automatically detect the modality of the input data.

@@ -110,22 +128,22 @@ def _get_features(

     """
     if modality == "image":
-        # Add type checking for images
-        image_data = [i for i in data if isinstance(i, Tensor)]
+        image_data = [i for i in data if isinstance(i, Tensor)]  # Add type checking for images
         processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
         return model.get_image_features(processed["pixel_values"].to(device))
     if modality == "text":
         processed = processor(text=data, return_tensors="pt", padding=True)
-        max_position_embeddings = model.config.text_config.max_position_embeddings
-        if processed["attention_mask"].shape[-1] > max_position_embeddings:
-            rank_zero_warn(
-                f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
-                "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
-                "longer sequences",
-                UserWarning,
-            )
-            processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
-            processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
+        if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "max_position_embeddings"):
+            max_position_embeddings = model.config.text_config.max_position_embeddings
+            if processed["attention_mask"].shape[-1] > max_position_embeddings:
+                rank_zero_warn(
+                    f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this"
+                    "length. If longer captions are needed, initialize argument `model_name_or_path` with a model that"
+                    "supports longer sequences.",
+                    UserWarning,
+                )
+                processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
+                processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
         return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
     raise ValueError(f"invalid modality {modality}")

@@ -136,6 +154,7 @@ def _clip_score_update(
     model: _CLIPModel,
     processor: _CLIPProcessor,
 ) -> tuple[Tensor, int]:
+    """Update function for CLIP Score."""
     source_modality = _detect_modality(source)
     target_modality = _detect_modality(target)

@@ -181,19 +200,43 @@ def _clip_score_update(


 def _get_clip_model_and_processor(
-    model_name_or_path: Literal[
-        "openai/clip-vit-base-patch16",
-        "openai/clip-vit-base-patch32",
-        "openai/clip-vit-large-patch14-336",
-        "openai/clip-vit-large-patch14",
-    ] = "openai/clip-vit-large-patch14",
+    model_name_or_path: Union[
+        Literal[
+            "openai/clip-vit-base-patch16",
+            "openai/clip-vit-base-patch32",
+            "openai/clip-vit-large-patch14-336",
+            "openai/clip-vit-large-patch14",
+            "jinaai/jina-clip-v2",
+            "zer0int/LongCLIP-L-Diffusers",
+            "zer0int/LongCLIP-GmP-ViT-L-14",
+        ],
+        Callable[[], tuple[_CLIPModel, _CLIPProcessor]],
+    ],
 ) -> tuple[_CLIPModel, _CLIPProcessor]:
+    if callable(model_name_or_path):
+        return model_name_or_path()
+
     if _TRANSFORMERS_GREATER_EQUAL_4_10:
+        from transformers import AutoModel, AutoProcessor
+        from transformers import CLIPConfig as _CLIPConfig
         from transformers import CLIPModel as _CLIPModel
         from transformers import CLIPProcessor as _CLIPProcessor

-        model = _CLIPModel.from_pretrained(model_name_or_path)
-        processor = _CLIPProcessor.from_pretrained(model_name_or_path)
+        if "openai" in model_name_or_path:
+            model = _CLIPModel.from_pretrained(model_name_or_path)
+            processor = _CLIPProcessor.from_pretrained(model_name_or_path)
+        elif "jinaai" in model_name_or_path:
+            model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+            processor = JinaProcessorWrapper(
+                processor=AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
+            )
+        elif "zer0int" in model_name_or_path:
+            config = _CLIPConfig.from_pretrained(model_name_or_path)
+            config.text_config.max_position_embeddings = 248
+            model = _CLIPModel.from_pretrained(model_name_or_path, config=config)
+            processor = _CLIPProcessor.from_pretrained(model_name_or_path, padding="max_length", max_length=248)
+        else:
+            raise ValueError(f"Invalid model_name_or_path {model_name_or_path}. Not supported by `clip_score` metric.")
         return model, processor

     raise ModuleNotFoundError(

@@ -205,11 +248,17 @@ def _get_clip_model_and_processor(
 def clip_score(
     source: Union[Tensor, List[Tensor], List[str], str],
     target: Union[Tensor, List[Tensor], List[str], str],
-    model_name_or_path: Literal[
-        "openai/clip-vit-base-patch16",
-        "openai/clip-vit-base-patch32",
-        "openai/clip-vit-large-patch14-336",
-        "openai/clip-vit-large-patch14",
+    model_name_or_path: Union[
+        Literal[
+            "openai/clip-vit-base-patch16",
+            "openai/clip-vit-base-patch32",
+            "openai/clip-vit-large-patch14-336",
+            "openai/clip-vit-large-patch14",
+            "jinaai/jina-clip-v2",
+            "zer0int/LongCLIP-L-Diffusers",
+            "zer0int/LongCLIP-GmP-ViT-L-14",
+        ],
+        Callable[[], tuple[_CLIPModel, _CLIPProcessor]],
     ] = "openai/clip-vit-large-patch14",
 ) -> Tensor:
     r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.

@@ -239,6 +288,11 @@ def clip_score(

     .. note:: Metric is not scriptable

+    .. note::
+        The default CLIP and processor used in this implementation has a maximum sequence length of 77 for text
+        inputs. If you need to process longer captions, you can use the `zer0int/LongCLIP-L-Diffusers` model which
+        has a maximum sequence length of 248.
+
     Args:
         source: Source input. This can be:
             - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.

@@ -251,7 +305,14 @@ def clip_score(
            - `"openai/clip-vit-base-patch32"`
            - `"openai/clip-vit-large-patch14-336"`
            - `"openai/clip-vit-large-patch14"`
-
+            - `"jinaai/jina-clip-v2"`
+            - `"zer0int/LongCLIP-L-Diffusers"`
+            - `"zer0int/LongCLIP-GmP-ViT-L-14"`
+
+            Alternatively, a callable function that returns a tuple of CLIP compatible model and processor instances
+            can be passed in. By compatible, we mean that the processors `__call__` method should accept a list of
+            strings and list of images and that the model should have a `get_image_features` and `get_text_features`
+            methods.

     Raises:
         ModuleNotFoundError:
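For orientation, below is a minimal usage sketch (not part of this commit) of the new callable form of `model_name_or_path` added to the functional metric. The loader name, checkpoint, and random inputs are illustrative; any callable works as long as the returned processor accepts text and images and the model exposes `get_image_features`/`get_text_features`, as the docstring above requires.

import torch
from torchmetrics.functional.multimodal import clip_score


def load_custom_clip():
    # Illustrative loader: returns a (model, processor) pair compatible with the metric.
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
    return model, processor


images = torch.randint(255, (2, 3, 224, 224), dtype=torch.uint8)  # toy image batch
captions = ["a photo of a cat", "a photo of a dog"]
score = clip_score(images, captions, model_name_or_path=load_custom_clip)
print(score)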

src/torchmetrics/multimodal/clip_score.py (+31 -6)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union

 import torch
 from torch import Tensor

@@ -26,6 +26,10 @@
 if not _MATPLOTLIB_AVAILABLE:
     __doctest_skip__ = ["CLIPScore.plot"]

+if TYPE_CHECKING and _TRANSFORMERS_GREATER_EQUAL_4_10:
+    from transformers import CLIPModel as _CLIPModel
+    from transformers import CLIPProcessor as _CLIPProcessor
+
 if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_10:
     from transformers import CLIPModel as _CLIPModel
     from transformers import CLIPProcessor as _CLIPProcessor

@@ -38,6 +42,8 @@ def _download_clip_for_clip_score() -> None:
         __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
 else:
     __doctest_skip__ = ["CLIPScore", "CLIPScore.plot"]
+    _CLIPModel = None
+    _CLIPProcessor = None


 class CLIPScore(Metric):

@@ -69,6 +75,11 @@ class CLIPScore(Metric):
     .. caution::
         Metric is not scriptable

+    .. note::
+        The default CLIP and processor used in this implementation has a maximum sequence length of 77 for text
+        inputs. If you need to process longer captions, you can use the `zer0int/LongCLIP-L-Diffusers` model which
+        has a maximum sequence length of 248.
+
     As input to ``forward`` and ``update`` the metric accepts the following input

     - source: Source input.

@@ -110,6 +121,14 @@ class CLIPScore(Metric):
            - `"openai/clip-vit-base-patch32"`
            - `"openai/clip-vit-large-patch14-336"`
            - `"openai/clip-vit-large-patch14"`
+            - `"jinaai/jina-clip-v2"`
+            - `"zer0int/LongCLIP-L-Diffusers"`
+            - `"zer0int/LongCLIP-GmP-ViT-L-14"`
+
+            Alternatively, a callable function that returns a tuple of CLIP compatible model and processor instances
+            can be passed in. By compatible, we mean that the processors `__call__` method should accept a list of
+            strings and list of images and that the model should have a `get_image_features` and `get_text_features`
+            methods.

         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

@@ -156,11 +175,17 @@ class CLIPScore(Metric):

     def __init__(
         self,
-        model_name_or_path: Literal[
-            "openai/clip-vit-base-patch16",
-            "openai/clip-vit-base-patch32",
-            "openai/clip-vit-large-patch14-336",
-            "openai/clip-vit-large-patch14",
+        model_name_or_path: Union[
+            Literal[
+                "openai/clip-vit-base-patch16",
+                "openai/clip-vit-base-patch32",
+                "openai/clip-vit-large-patch14-336",
+                "openai/clip-vit-large-patch14",
+                "jinaai/jina-clip-v2",
+                "zer0int/LongCLIP-L-Diffusers",
+                "zer0int/LongCLIP-GmP-ViT-L-14",
+            ],
+            Callable[[], tuple[_CLIPModel, _CLIPProcessor]],
         ] = "openai/clip-vit-large-patch14",
         **kwargs: Any,
     ) -> None:
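Likewise, a short sketch (not from the diff) of pointing the modular metric at one of the newly supported checkpoints; per the note added above, `zer0int/LongCLIP-L-Diffusers` raises the text limit from 77 to 248 tokens. The toy inputs and the assumption that the checkpoint is reachable are mine.

import torch
from torchmetrics.multimodal.clip_score import CLIPScore

# Long-caption variant added by this PR; the default "openai/clip-vit-large-patch14"
# would truncate captions at 77 tokens and emit the warning added in _get_features.
metric = CLIPScore(model_name_or_path="zer0int/LongCLIP-L-Diffusers")

images = torch.randint(255, (2, 3, 224, 224), dtype=torch.uint8)  # toy image batch
captions = [
    "a highly detailed photo of a tabby cat sleeping on a windowsill in the afternoon sun",
    "an aerial view of a coastal town with red rooftops and fishing boats in the harbour",
]
metric.update(images, captions)
print(metric.compute())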

tests/unittests/_helpers/__init__.py (+2 -2)

@@ -16,7 +16,7 @@
 import numpy
 import torch

-from unittests._helpers.wrappers import skip_on_connection_issues, skip_on_running_out_of_memory
+from unittests._helpers.wrappers import skip_on_connection_issues, skip_on_cuda_oom, skip_on_running_out_of_memory


 def seed_all(seed):

@@ -27,4 +27,4 @@ def seed_all(seed):
     torch.cuda.manual_seed_all(seed)


-__all__ = ["seed_all", "skip_on_connection_issues", "skip_on_running_out_of_memory"]
+__all__ = ["seed_all", "skip_on_connection_issues", "skip_on_cuda_oom", "skip_on_running_out_of_memory"]

tests/unittests/_helpers/testers.py (+8 -2)

@@ -98,6 +98,7 @@ def _class_test(
     fragment_kwargs: bool = False,
     check_scriptable: bool = True,
     check_state_dict: bool = True,
+    check_picklable: bool = True,
     **kwargs_update: Any,
 ):
     """Comparison between class metric and reference metric.

@@ -121,6 +122,7 @@
         fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
         check_scriptable: bool indicating if metric should also be tested if it can be scripted
         check_state_dict: bool indicating if metric should be tested that its state_dict by default is empty
+        check_picklable: bool indicating if metric should be tested that it can be pickled
         kwargs_update: Additional keyword arguments that will be passed with preds and
             target when running update on the metric.

@@ -156,8 +158,9 @@
     kwargs_update = {k: v.to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items()}

     # verify metrics work after being loaded from pickled state
-    pickled_metric = pickle.dumps(metric)
-    metric = pickle.loads(pickled_metric)
+    if check_picklable:
+        pickled_metric = pickle.dumps(metric)
+        metric = pickle.loads(pickled_metric)
     metric_clone = deepcopy(metric)

     for i in range(rank, num_batches, world_size):

@@ -431,6 +434,7 @@ def run_class_metric_test(
         fragment_kwargs: bool = False,
         check_scriptable: bool = True,
         check_state_dict: bool = True,
+        check_picklable: bool = True,
         atol: Optional[float] = None,
         **kwargs_update: Any,
     ):

@@ -451,6 +455,7 @@
            fragment_kwargs: whether tensors in kwargs should be divided as `preds` and `target` among processes
            check_scriptable: bool indicating if metric should also be tested if it can be scripted
            check_state_dict: bool indicating if metric should be tested that its state_dict by default is empty
+            check_picklable: bool indicating if metric should be tested that it can be pickled
            atol: absolute tolerance used for comparison of results, if None will use self.atol
            kwargs_update: Additional keyword arguments that will be passed with preds and
                target when running update on the metric.

@@ -470,6 +475,7 @@
             "fragment_kwargs": fragment_kwargs,
             "check_scriptable": check_scriptable,
             "check_state_dict": check_state_dict,
+            "check_picklable": check_picklable,
         }

         if ddp and hasattr(pytest, "pool"):
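The new `check_picklable` switch matters because some of the added checkpoints (for example the `trust_remote_code` Jina model) may not survive a pickle round-trip. A hypothetical call from a test class might look like the sketch below; everything except `check_picklable` (fixture names, the reference function, the metric arguments) is illustrative rather than taken from the commit.

from torchmetrics.multimodal.clip_score import CLIPScore
from unittests._helpers.testers import MetricTester


class TestClipScore(MetricTester):
    def test_clip_score(self, preds, target, reference_fn):
        # preds/target/reference_fn are hypothetical fixtures for this sketch.
        self.run_class_metric_test(
            ddp=False,
            preds=preds,
            target=target,
            metric_class=CLIPScore,
            reference_metric=reference_fn,  # hypothetical reference implementation
            metric_args={"model_name_or_path": "jinaai/jina-clip-v2"},
            check_picklable=False,  # skip the pickle round-trip guarded by the new flag
        )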

tests/unittests/_helpers/wrappers.py (+22)

@@ -69,3 +69,25 @@ def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
         return run_test

     return test_decorator
+
+
+def skip_on_cuda_oom(reason: str = "Skipping test due to CUDA Out of Memory (OOM) error."):
+    """Skip tests that fail due to CUDA Out of Memory (OOM) errors.
+
+    The test runs normally if no OOM error arises, but is marked as skipped otherwise.
+
+    """
+
+    def test_decorator(function: Callable) -> Callable:
+        @wraps(function)
+        def run_test(*args: Any, **kwargs: Any) -> Optional[Any]:
+            try:
+                return function(*args, **kwargs)
+            except RuntimeError as ex:
+                if "CUDA out of memory" not in str(ex):
+                    raise ex
+                pytest.skip(reason)
+
+        return run_test
+
+    return test_decorator
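For completeness, a hypothetical test (not in the commit) showing how the new decorator would be applied; the test body, checkpoint, and inputs are assumptions used only to illustrate the skip behaviour.

import pytest
import torch

from unittests._helpers import skip_on_cuda_oom


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU")
@skip_on_cuda_oom()
def test_clip_score_large_model_on_gpu():
    from torchmetrics.multimodal.clip_score import CLIPScore

    # Loading a large checkpoint may exhaust GPU memory on small CI runners; instead of
    # failing, the decorator converts the "CUDA out of memory" RuntimeError into a skip.
    metric = CLIPScore(model_name_or_path="zer0int/LongCLIP-L-Diffusers").to("cuda")
    score = metric(torch.randint(255, (2, 3, 224, 224), device="cuda"), ["a cat", "a dog"])
    assert score >= 0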
