 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Union, cast

 import torch
 from torch import Tensor
@@ -41,53 +41,143 @@ def _download_clip_for_clip_score() -> None:
     _CLIPProcessor = None


+def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
+    """Automatically detect the modality of the input data.
+
+    Args:
+        input_data: Input data that can be either image tensors or text strings
+
+    Returns:
+        str: Either "image" or "text"
+
+    Raises:
+        ValueError: If the input_data is an empty list or modality cannot be determined
+
+    """
+    if isinstance(input_data, Tensor):
+        return "image"
+
+    if isinstance(input_data, list):
+        if len(input_data) == 0:
+            raise ValueError("Empty input list")
+        if isinstance(input_data[0], Tensor):
+            return "image"
+        if isinstance(input_data[0], str):
+            return "text"
+
+    if isinstance(input_data, str):
+        return "text"
+
+    raise ValueError("Could not automatically determine modality for input_data")
+
+
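An illustrative sketch (not part of the diff) of how the dispatch rule above resolves inputs, assuming the private helper becomes importable from `torchmetrics.functional.multimodal.clip_score` once this change is merged:

import torch
from torchmetrics.functional.multimodal.clip_score import _detect_modality

# A tensor, or a list whose first element is a tensor, is treated as image input.
print(_detect_modality(torch.rand(3, 224, 224)))    # -> "image"
print(_detect_modality([torch.rand(3, 224, 224)]))  # -> "image"
# A string, or a list whose first element is a string, is treated as text input.
print(_detect_modality("a photo of a cat"))         # -> "text"
print(_detect_modality(["a cat", "a dog"]))         # -> "text"
# An empty list or any other type raises ValueError.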
+def _process_image_data(images: Union[Tensor, List[Tensor]]) -> List[Tensor]:
+    """Helper function to process image data."""
+    images = [images] if not isinstance(images, list) and images.ndim == 3 else list(images)
+    if not all(i.ndim == 3 for i in images):
+        raise ValueError("Expected all images to be 3d but found image that has either more or less")
+    return images
+
+
+def _process_text_data(texts: Union[str, List[str]]) -> List[str]:
+    """Helper function to process text data."""
+    if not isinstance(texts, list):
+        texts = [texts]
+    return texts
+
+
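A small sketch (again illustrative, with the same import assumption) of how the two helpers normalize their inputs into lists:

import torch
from torchmetrics.functional.multimodal.clip_score import _process_image_data, _process_text_data

batch = torch.rand(2, 3, 224, 224)                         # a [N, C, H, W] tensor
print(len(_process_image_data(batch)))                     # 2 -> unpacked into two [C, H, W] tensors
print(len(_process_image_data(torch.rand(3, 224, 224))))   # 1 -> a single [C, H, W] image is wrapped in a list
print(_process_text_data("a single caption"))              # ['a single caption']
# A 2d tensor (e.g. shape [224, 224]) would raise ValueError, since all images must be 3d.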
+def _get_features(
+    data: List[Union[Tensor, str]],
+    modality: str,
+    device: torch.device,
+    model: "_CLIPModel",
+    processor: "_CLIPProcessor",
+) -> Tensor:
+    """Get features from the CLIP model for either images or text.
+
+    Args:
+        data: List of input data (images or text)
+        modality: String indicating the type of input data (must be either "image" or "text")
+        device: Device to run the model on
+        model: CLIP model instance
+        processor: CLIP processor instance
+
+    Returns:
+        Tensor of features from the CLIP model
+
+    Raises:
+        ValueError: If modality is not "image" or "text"
+
+    """
+    if modality == "image":
+        # Add type checking for images
+        image_data = [i for i in data if isinstance(i, Tensor)]
+        processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
+        return model.get_image_features(processed["pixel_values"].to(device))
+    if modality == "text":
+        processed = processor(text=data, return_tensors="pt", padding=True)
+        max_position_embeddings = model.config.text_config.max_position_embeddings
+        if processed["attention_mask"].shape[-1] > max_position_embeddings:
+            rank_zero_warn(
+                f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
+                "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
+                "longer sequences",
+                UserWarning,
+            )
+            processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
+            processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
+        return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
+    raise ValueError(f"invalid modality {modality}")
+
+
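The text branch's truncation step can be shown in isolation with dummy tensors (a hedged sketch; the real code reads `max_position_embeddings` from the loaded CLIP text config, and the token values here are arbitrary):

import torch

max_position_embeddings = 77                    # assumed limit of the text encoder
input_ids = torch.randint(0, 1000, (1, 120))    # a tokenized caption longer than the limit
attention_mask = torch.ones(1, 120, dtype=torch.long)

# Mirror the truncation performed in _get_features before calling get_text_features.
input_ids = input_ids[..., :max_position_embeddings]
attention_mask = attention_mask[..., :max_position_embeddings]
print(input_ids.shape, attention_mask.shape)    # torch.Size([1, 77]) torch.Size([1, 77])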
 def _clip_score_update(
-    images: Union[Tensor, List[Tensor]],
-    text: Union[str, list[str]],
+    source: Union[Tensor, List[Tensor], List[str], str],
+    target: Union[Tensor, List[Tensor], List[str], str],
     model: _CLIPModel,
     processor: _CLIPProcessor,
 ) -> tuple[Tensor, int]:
-    if not isinstance(images, list):
-        if images.ndim == 3:
-            images = [images]
-        else:  # unwrap into list
-            images = list(images)
-
-    if not all(i.ndim == 3 for i in images):
-        raise ValueError("Expected all images to be 3d but found image that has either more or less")
+    source_modality = _detect_modality(source)
+    target_modality = _detect_modality(target)

-    if not isinstance(text, list):
-        text = [text]
+    source_data = (
+        _process_image_data(cast(Union[Tensor, List[Tensor]], source))
+        if source_modality == "image"
+        else _process_text_data(cast(Union[str, List[str]], source))
+    )
+    target_data = (
+        _process_image_data(cast(Union[Tensor, List[Tensor]], target))
+        if target_modality == "image"
+        else _process_text_data(cast(Union[str, List[str]], target))
+    )

-    if len(text) != len(images):
+    if len(source_data) != len(target_data):
         raise ValueError(
-            f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}"
-        )
-    device = images[0].device
-    processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True)
-
-    img_features = model.get_image_features(processed_input["pixel_values"].to(device))
-    img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True)
-
-    max_position_embeddings = model.config.text_config.max_position_embeddings
-    if processed_input["attention_mask"].shape[-1] > max_position_embeddings:
-        rank_zero_warn(
-            f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
-            "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
-            "longer sequences",
-            UserWarning,
+            "Expected the number of source and target examples to be the same but got "
+            f"{len(source_data)} and {len(target_data)}"
         )
-        processed_input["attention_mask"] = processed_input["attention_mask"][..., :max_position_embeddings]
-        processed_input["input_ids"] = processed_input["input_ids"][..., :max_position_embeddings]

-    txt_features = model.get_text_features(
-        processed_input["input_ids"].to(device), processed_input["attention_mask"].to(device)
+    device = (
+        source_data[0].device
+        if source_modality == "image" and isinstance(source_data[0], Tensor)
+        else target_data[0].device
+        if target_modality == "image" and isinstance(target_data[0], Tensor)
+        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
     )
-    txt_features = txt_features / txt_features.norm(p=2, dim=-1, keepdim=True)
+    model = model.to(device)

-    # cosine similarity between feature vectors
-    score = 100 * (img_features * txt_features).sum(axis=-1)
-    return score, len(text)
+    source_features = _get_features(
+        cast(List[Union[Tensor, str]], source_data), source_modality, device, model, processor
+    )
+    target_features = _get_features(
+        cast(List[Union[Tensor, str]], target_data), target_modality, device, model, processor
+    )
+    source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
+    target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)
+
+    # Calculate cosine similarity
+    score = 100 * (source_features * target_features).sum(axis=-1)
+    score = score.cpu() if source_modality == "text" and target_modality == "text" else score
+    return score, len(source_data)
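The scoring step at the end is just a scaled cosine similarity between L2-normalized embeddings, which can be checked with plain tensors (an illustrative sketch independent of CLIP):

import torch

source_features = torch.randn(4, 512)
target_features = torch.randn(4, 512)
# Normalize each embedding to unit length, as in _clip_score_update.
source_features = source_features / source_features.norm(p=2, dim=-1, keepdim=True)
target_features = target_features / target_features.norm(p=2, dim=-1, keepdim=True)

# Elementwise product summed over the feature dimension equals the cosine similarity.
score = 100 * (source_features * target_features).sum(dim=-1)
reference = 100 * torch.nn.functional.cosine_similarity(source_features, target_features)
print(torch.allclose(score, reference))  # True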


 def _get_clip_model_and_processor(
@@ -113,20 +203,20 @@ def _get_clip_model_and_processor(


 def clip_score(
-    images: Union[Tensor, List[Tensor]],
-    text: Union[str, list[str]],
+    source: Union[Tensor, List[Tensor], List[str], str],
+    target: Union[Tensor, List[Tensor], List[str], str],
     model_name_or_path: Literal[
         "openai/clip-vit-base-patch16",
         "openai/clip-vit-base-patch32",
         "openai/clip-vit-large-patch14-336",
         "openai/clip-vit-large-patch14",
     ] = "openai/clip-vit-large-patch14",
 ) -> Tensor:
-    r"""Calculate `CLIP Score`_ which is a text-to-image similarity metric.
+    r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.

     CLIP Score is a reference free metric that can be used to evaluate the correlation between a generated caption for
-    an image and the actual content of the image. It has been found to be highly correlated with human judgement. The
-    metric is defined as:
+    an image and the actual content of the image, as well as the similarity between texts or images. It has been found
+    to be highly correlated with human judgement. The metric is defined as:

     .. math::
         \text{CLIPScore(I, C)} = max(100 * cos(E_I, E_C), 0)
@@ -135,15 +225,33 @@ def clip_score(
     textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer
     to 100 the better.

-    .. caution::
-        Metric is not scriptable
+    Additionally, the CLIP Score can be calculated for the same modalities:
+
+    .. math::
+        \text{CLIPScore(I_1, I_2)} = max(100 * cos(E_{I_1}, E_{I_2}), 0)
+
+    where :math:`E_{I_1}` and :math:`E_{I_2}` are the visual embeddings for images :math:`I_1` and :math:`I_2`.
+
+    .. math::
+        \text{CLIPScore(T_1, T_2)} = max(100 * cos(E_{T_1}, E_{T_2}), 0)
+
+    where :math:`E_{T_1}` and :math:`E_{T_2}` are the textual embeddings for texts :math:`T_1` and :math:`T_2`.
+
+    .. note:: Metric is not scriptable

     Args:
-        images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
-        text: Either a single caption or a list of captions
-        model_name_or_path: string indicating the version of the CLIP model to use. Available models are
-            `"openai/clip-vit-base-patch16"`, `"openai/clip-vit-base-patch32"`, `"openai/clip-vit-large-patch14-336"`
-            and `"openai/clip-vit-large-patch14"`,
+        source: Source input. This can be:
+            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+            - Text: Either a single caption or a list of captions.
+        target: Target input. This can be:
+            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
+            - Text: Either a single caption or a list of captions.
+        model_name_or_path: String indicating the version of the CLIP model to use. Available models are:
+            - `"openai/clip-vit-base-patch16"`
+            - `"openai/clip-vit-base-patch32"`
+            - `"openai/clip-vit-large-patch14-336"`
+            - `"openai/clip-vit-large-patch14"`
+

     Raises:
         ModuleNotFoundError:
@@ -155,13 +263,31 @@ def clip_score(

     Example:
         >>> from torchmetrics.functional.multimodal import clip_score
-        >>> score = clip_score(torch.randint(255, (3, 224, 224)), "a photo of a cat", "openai/clip-vit-base-patch16")
+        >>> image = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> score = clip_score(image, "a photo of a cat", "openai/clip-vit-base-patch16")
         >>> score.detach()
         tensor(24.4255)

+    Example:
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> image1 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(42))
+        >>> image2 = torch.randint(255, (3, 224, 224), generator=torch.Generator().manual_seed(43))
+        >>> score = clip_score(image1, image2, "openai/clip-vit-base-patch16")
+        >>> score.detach()
+        tensor(99.4859)
+
+    Example:
+        >>> from torchmetrics.functional.multimodal import clip_score
+        >>> score = clip_score(
+        ...     "28-year-old chef found dead in San Francisco mall",
+        ...     "A 28-year-old chef who recently moved to San Francisco was found dead.",
+        ...     "openai/clip-vit-base-patch16"
+        ... )
+        >>> score.detach()
+        tensor(91.3950)
+
     """
     model, processor = _get_clip_model_and_processor(model_name_or_path)
-    device = images.device if isinstance(images, Tensor) else images[0].device
-    score, _ = _clip_score_update(images, text, model.to(device), processor)
+    score, _ = _clip_score_update(source, target, model, processor)
     score = score.mean(0)
     return torch.max(score, torch.zeros_like(score))