
Commit 5fc2e0b

Authored by rittik9, Borda, mergify[bot], and pre-commit-ci[bot]
Enhance Clip_Score to calculate similarities between same modalities (#2875)
* Handle zero division error in binary IoU (Jaccard index) calculation

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 520a868 commit 5fc2e0b
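In practice, the change lets `clip_score` accept any source/target modality pairing (image-text, image-image, text-text). Below is a minimal usage sketch, assuming a torchmetrics build that includes this commit and the `transformers` extra installed; the call patterns mirror the docstring examples in the diff further down.

# Usage sketch for the new same-modality support (assumption: this commit is
# installed; exact score values depend on the CLIP weights and inputs).
import torch
from torchmetrics.functional.multimodal import clip_score

img1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
img2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))

# Existing behaviour: image-text similarity.
it_score = clip_score(img1, "a photo of a cat", "openai/clip-vit-base-patch16")

# New: image-image similarity.
ii_score = clip_score(img1, img2, "openai/clip-vit-base-patch16")

# New: text-text similarity.
tt_score = clip_score(
    "28-year-old chef found dead in San Francisco mall",
    "A 28-year-old chef who recently moved to San Francisco was found dead.",
    "openai/clip-vit-base-patch16",
)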

File tree

3 files changed: +330 -69 lines changed


src/torchmetrics/functional/multimodal/clip_score.py

+178 -52
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Union, cast
 
 import torch
 from torch import Tensor
@@ -41,53 +41,143 @@ def _download_clip_for_clip_score() -> None:
     _CLIPProcessor = None
 
 
+def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
+    """Automatically detect the modality of the input data.
+
+    Args:
+        input_data: Input data that can be either image tensors or text strings
+
+    Returns:
+        str: Either "image" or "text"
+
+    Raises:
+        ValueError: If the input_data is an empty list or modality cannot be determined
+
+    """
+    if isinstance(input_data, Tensor):
+        return "image"
+
+    if isinstance(input_data, list):
+        if len(input_data) == 0:
+            raise ValueError("Empty input list")
+        if isinstance(input_data[0], Tensor):
+            return "image"
+        if isinstance(input_data[0], str):
+            return "text"
+
+    if isinstance(input_data, str):
+        return "text"
+
+    raise ValueError("Could not automatically determine modality for input_data")
+
+
+def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
+    """Helper function to process image data."""
+    images = [images] if not isinstance(images, list) and images.ndim == 3 else list(images)
+    if not all(i.ndim == 3 for i in images):
+        raise ValueError("Expected all images to be 3d but found image that has either more or less")
+    return images
+
+
+def _process_text_data(texts: Union[str, List[str]]) -> List[str]:
+    """Helper function to process text data."""
+    if not isinstance(texts, list):
+        texts = [texts]
+    return texts
+
+
+def _get_features(
+    data: List[Union[Tensor, str]],
+    modality: str,
+    device: torch.device,
+    model: "_CLIPModel",
+    processor: "_CLIPProcessor",
+) -> Tensor:
+    """Get features from the CLIP model for either images or text.
+
+    Args:
+        data: List of input data (images or text)
+        modality: String indicating the type of input data (must be either "image" or "text")
+        device: Device to run the model on
+        model: CLIP model instance
+        processor: CLIP processor instance
+
+    Returns:
+        Tensor of features from the CLIP model
+
+    Raises:
+        ValueError: If modality is not "image" or "text"
+
+    """
+    if modality == "image":
+        # Add type checking for images
+        image_data = [i for i in data if isinstance(i, Tensor)]
+        processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
+        return model.get_image_features(processed["pixel_values"].to(device))
+    if modality == "text":
+        processed = processor(text=data, return_tensors="pt", padding=True)
+        max_position_embeddings = model.config.text_config.max_position_embeddings
+        if processed["attention_mask"].shape[-1] > max_position_embeddings:
+            rank_zero_warn(
+                f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
+                "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
+                "longer sequences",
+                UserWarning,
+            )
+            processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
+            processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
+        return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
+    raise ValueError(f"invalid modality {modality}")
+
+
 def _clip_score_update(
-    images: Union[Tensor, List[Tensor]],
-    text: Union[str, list[str]],
+    source: Union[Tensor, List[Tensor], List[str], str],
+    target: Union[Tensor, List[Tensor], List[str], str],
     model: _CLIPModel,
     processor: _CLIPProcessor,
 ) -> tuple[Tensor, int]:
-    if not isinstance(images, list):
-        if images.ndim == 3:
-            images = [images]
-        else:  # unwrap into list
-            images = list(images)
-
-    if not all(i.ndim == 3 for i in images):
-        raise ValueError("Expected all images to be 3d but found image that has either more or less")
+    source_modality = _detect_modality(source)
+    target_modality = _detect_modality(target)
 
-    if not isinstance(text, list):
-        text = [text]
+    source_data = (
+        _process_image_data(cast(Union[Tensor, List[Tensor]], source))
+        if source_modality == "image"
+        else _process_text_data(cast(Union[str, List[str]], source))
+    )
+    target_data = (
+        _process_image_data(cast(Union[Tensor, List[Tensor]], target))
+        if target_modality == "image"
+        else _process_text_data(cast(Union[str, List[str]], target))
+    )
 
-    if len(text) != len(images):
+    if len(source_data) != len(target_data):
         raise ValueError(
-            f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
-        )
-    device = images[0].device
-    processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True)
-
-    img_features = model.get_image_features(processed_input["pixel_values"].to(device))
-    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
-
-    max_position_embeddings = model.config.text_config.max_position_embeddings
-    if processed_input["attention_mask"].shape[-1] > max_position_embeddings:
-        rank_zero_warn(
-            f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
-            "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
-            "longer sequences",
-            UserWarning,
+            "Expected the number of source and target examples to be the same but got "
+            f"{len(source_data)} and {len(target_data)}"
         )
-        processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings]
-        processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings]
 
-    txt_features = model.get_text_features(
-        processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
+    device = (
+        source_data[0].device
+        if source_modality == "image" and isinstance(source_data[0], Tensor)
+        else target_data[0].device
+        if target_modality == "image" and isinstance(target_data[0], Tensor)
+        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
-    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
+    model = model.to(device)
 
-    # cosine similarity between feature vectors
-    score = 100 * (img_features * txt_features).sum(axis=-1)
-    return score, len(text)
+    source_features = _get_features(
+        cast(List[Union[Tensor, str]], source_data), source_modality, device, model, processor
+    )
+    target_features = _get_features(
+        cast(List[Union[Tensor, str]], target_data), target_modality, device, model, processor
+    )
+    source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
+    target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)
+
+    # Calculate cosine similarity
+    score = 100 * (source_features * target_features).sum(axis=-1)
+    score = score.cpu() if source_modality == "text" and target_modality == "text" else score
+    return score, len(source_data)
 
 
 def _get_clip_model_and_processor(
@@ -113,20 +203,20 @@ def _get_clip_model_and_processor(
 
 
 def clip_score(
-    images: Union[Tensor, List[Tensor]],
-    text: Union[str, list[str]],
+    source: Union[Tensor, List[Tensor], List[str], str],
+    target: Union[Tensor, List[Tensor], List[str], str],
     model_name_or_path: Literal[
         "openai/clip-vit-base-patch16",
         "openai/clip-vit-base-patch32",
         "openai/clip-vit-large-patch14-336",
         "openai/clip-vit-large-patch14",
     ] = "openai/clip-vit-large-patch14",
 ) -> Tensor:
-    r"""Calculate `CLIP Score`_ which is a text-to-image similarity metric.
+    r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.
 
     CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for
-    an image and the actual content of the image. It has been found to be highly correlated with human judgement. The
-    metric is defined as:
+    an image and the actual content of the image, as well as the similarity between texts or images. It has been found
+    to be highly correlated with human judgement. The metric is defined as:
 
     .. math::
         \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)
@@ -135,15 +225,33 @@ def clip_score(
     textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
     to 100 the better.
 
-    .. caution::
-        Metric is not scriptable
+    Additionally, the CLIP Score can be calculated for the same modalities:
+
+    .. math::
+        \text{CLIPScore(I_1, I_2)} = max(100 * cos(E_{I_1}, E_{I_2}), 0)
+
+    where :math:`E_{I_1}` and :math:`E_{I_2}` are the visual embeddings for images :math:`I_1` and :math:`I_2`.
+
+    .. math::
+        \text{CLIPScore(T_1, T_2)} = max(100 * cos(E_{T_1}, E_{T_2}), 0)
+
+    where :math:`E_{T_1}` and :math:`E_{T_2}` are the textual embeddings for texts :math:`T_1` and :math:`T_2`.
+
+    .. note:: Metric is not scriptable
 
     Args:
-        images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
-        text: Either a single caption or a list of captions
-        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
-            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
-            and `"openai/clip-vit-large-patch14"`,
+        source: Source input. This can be:
+            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+            - Text: Either a single caption or a list of captions.
+        target: Target input. This can be:
+            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+            - Text: Either a single caption or a list of captions.
+        model_name_or_path: String indicating the version of the CLIP model to use. Available models are:
+            - `"openai/clip-vit-base-patch16"`
+            - `"openai/clip-vit-base-patch32"`
+            - `"openai/clip-vit-large-patch14-336"`
+            - `"openai/clip-vit-large-patch14"`
+
 
     Raises:
         ModuleNotFoundError:
@@ -155,13 +263,31 @@ def clip_score(
 
     Example:
         >>> from torchmetrics.functional.multimodal import clip_score
-        >>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
+        >>> image = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> score = clip_score(image, "a photo of a cat", "openai/clip-vit-base-patch16")
         >>> score.detach()
         tensor(24.4255)
 
+    Example:
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> image1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> image2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
+        >>> score = clip_score(image1, image2, "openai/clip-vit-base-patch16")
+        >>> score.detach()
+        tensor(99.4859)
+
+    Example:
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> score = clip_score(
+        ...     "28-year-old chef found dead in San Francisco mall",
+        ...     "A 28-year-old chef who recently moved to San Francisco was found dead.",
+        ...     "openai/clip-vit-base-patch16"
+        ... )
+        >>> score.detach()
+        tensor(91.3950)
+
     """
     model, processor = _get_clip_model_and_processor(model_name_or_path)
-    device = images.device if isinstance(images, Tensor) else images[0].device
-    score, _ = _clip_score_update(images, text, model.to(device), processor)
+    score, _ = _clip_score_update(source, target, model, processor)
     score = score.mean(0)
     return torch.max(score, torch.zeros_like(score))
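
For reference, the scoring step shared by the old and new code paths can be reproduced in isolation. A minimal sketch using random stand-in embeddings (the real features come from `_get_features` in the diff above; values here are illustrative only):

import torch

# Random stand-ins for the CLIP embeddings produced by _get_features.
source_features = torch.randn(4, 512)
target_features = torch.randn(4, 512)

# L2-normalize both sides, take the cosine similarity, and scale by 100,
# as _clip_score_update does.
source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)
score = 100 * (source_features * target_features).sum(dim=-1)

# Average over the batch and clamp at zero, as clip_score does before returning.
score = score.mean(0)
score = torch.max(score, torch.zeros_like(score))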
