# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import TYPE_CHECKING, List, Union, cast
+ from typing import TYPE_CHECKING, Any, Callable, List, Union, cast

import torch
from torch import Tensor
@@ -41,6 +41,24 @@ def _download_clip_for_clip_score() -> None:
    _CLIPProcessor = None


+ class JinaProcessorWrapper:
+     """Wrapper class that converts tensors to PIL images if needed for the Jina CLIP model."""
+
+     def __init__(self, processor: _CLIPProcessor) -> None:
+         self.processor = processor
+
+     def __call__(self, *args: Any, **kwargs: Any) -> Any:
+         """Wrap the processor's __call__ method, converting tensors to PIL images if needed."""
+         # Check if 'images' is in kwargs and convert any tensors to PIL images before delegating
+         from torchvision.transforms.functional import to_pil_image
+
+         if "images" in kwargs:
+             kwargs["images"] = [
+                 to_pil_image(img.float().cpu()) if isinstance(img, Tensor) else img for img in kwargs["images"]
+             ]
+         return self.processor(*args, **kwargs)
+
+
def _detect_modality(input_data: Union[Tensor, List[Tensor], List[str], str]) -> Literal["image", "text"]:
    """Automatically detect the modality of the input data.
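As a quick illustration of the wrapper's behaviour (not part of the diff): assuming the `JinaProcessorWrapper` class added above is in scope, the dummy `_EchoProcessor` below is a purely illustrative stand-in for the real Jina processor; it only reports the types it receives, showing that tensor inputs arrive as PIL images.

```python
import torch
from torchvision.transforms.functional import to_pil_image


class _EchoProcessor:
    """Stand-in for the real Jina processor; it just reports the types of the images it receives."""

    def __call__(self, *args, **kwargs):
        return [type(img).__name__ for img in kwargs.get("images", [])]


wrapped = JinaProcessorWrapper(_EchoProcessor())
tensor_img = torch.rand(3, 224, 224)  # float tensor with values in [0, 1]
pil_img = to_pil_image(torch.rand(3, 64, 64))  # already a PIL image

print(wrapped(images=[tensor_img, pil_img]))  # ['Image', 'Image'] -- the tensor was converted to PIL
```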
@@ -110,22 +128,22 @@ def _get_features(
    """
    if modality == "image":
-         # Add type checking for images
-         image_data = [i for i in data if isinstance(i, Tensor)]
+         image_data = [i for i in data if isinstance(i, Tensor)]  # Add type checking for images
        processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
        return model.get_image_features(processed["pixel_values"].to(device))
    if modality == "text":
        processed = processor(text=data, return_tensors="pt", padding=True)
-         max_position_embeddings = model.config.text_config.max_position_embeddings
-         if processed["attention_mask"].shape[-1] > max_position_embeddings:
-             rank_zero_warn(
-                 f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this length."
-                 "If longer captions are needed, initialize argument `model_name_or_path` with a model that supports"
-                 "longer sequences",
-                 UserWarning,
-             )
-             processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
-             processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
+         if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "max_position_embeddings"):
+             max_position_embeddings = model.config.text_config.max_position_embeddings
+             if processed["attention_mask"].shape[-1] > max_position_embeddings:
+                 rank_zero_warn(
+                     f"Encountered caption longer than {max_position_embeddings=}. Will truncate captions to this "
+                     "length. If longer captions are needed, initialize argument `model_name_or_path` with a model "
+                     "that supports longer sequences.",
+                     UserWarning,
+                 )
+                 processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
+                 processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
        return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
    raise ValueError(f"invalid modality {modality}")
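For readers unfamiliar with the truncation step, here is a standalone sketch of what the slicing above does; the tensors are dummies rather than real tokenizer output, and 77 is the limit of the default OpenAI CLIP text encoder.

```python
import torch

# Dummy tokenizer output that is longer than the model's position limit
processed = {
    "input_ids": torch.randint(0, 1000, (1, 100)),
    "attention_mask": torch.ones(1, 100, dtype=torch.long),
}
max_position_embeddings = 77  # e.g. openai/clip-vit-large-patch14

if processed["attention_mask"].shape[-1] > max_position_embeddings:
    # Keep only the first `max_position_embeddings` tokens along the sequence dimension
    processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
    processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]

print(processed["input_ids"].shape)  # torch.Size([1, 77])
```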
@@ -136,6 +154,7 @@ def _clip_score_update(
    model: _CLIPModel,
    processor: _CLIPProcessor,
) -> tuple[Tensor, int]:
+     """Update function for CLIP Score."""
    source_modality = _detect_modality(source)
    target_modality = _detect_modality(target)
@@ -181,19 +200,43 @@ def _clip_score_update(
def _get_clip_model_and_processor(
-     model_name_or_path: Literal[
-         "openai/clip-vit-base-patch16",
-         "openai/clip-vit-base-patch32",
-         "openai/clip-vit-large-patch14-336",
-         "openai/clip-vit-large-patch14",
-     ] = "openai/clip-vit-large-patch14",
+     model_name_or_path: Union[
+         Literal[
+             "openai/clip-vit-base-patch16",
+             "openai/clip-vit-base-patch32",
+             "openai/clip-vit-large-patch14-336",
+             "openai/clip-vit-large-patch14",
+             "jinaai/jina-clip-v2",
+             "zer0int/LongCLIP-L-Diffusers",
+             "zer0int/LongCLIP-GmP-ViT-L-14",
+         ],
+         Callable[[], tuple[_CLIPModel, _CLIPProcessor]],
+     ],
) -> tuple[_CLIPModel, _CLIPProcessor]:
+     if callable(model_name_or_path):
+         return model_name_or_path()
+
    if _TRANSFORMERS_GREATER_EQUAL_4_10:
+         from transformers import AutoModel, AutoProcessor
+         from transformers import CLIPConfig as _CLIPConfig
        from transformers import CLIPModel as _CLIPModel
        from transformers import CLIPProcessor as _CLIPProcessor

-         model = _CLIPModel.from_pretrained(model_name_or_path)
-         processor = _CLIPProcessor.from_pretrained(model_name_or_path)
+         if "openai" in model_name_or_path:
+             model = _CLIPModel.from_pretrained(model_name_or_path)
+             processor = _CLIPProcessor.from_pretrained(model_name_or_path)
+         elif "jinaai" in model_name_or_path:
+             model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
+             processor = JinaProcessorWrapper(
+                 processor=AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
+             )
+         elif "zer0int" in model_name_or_path:
+             config = _CLIPConfig.from_pretrained(model_name_or_path)
+             config.text_config.max_position_embeddings = 248
+             model = _CLIPModel.from_pretrained(model_name_or_path, config=config)
+             processor = _CLIPProcessor.from_pretrained(model_name_or_path, padding="max_length", max_length=248)
+         else:
+             raise ValueError(f"Invalid model_name_or_path {model_name_or_path}. Not supported by `clip_score` metric.")
        return model, processor

    raise ModuleNotFoundError(
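Since `model_name_or_path` may now be a callable that is invoked before any of the name-based branching above, a rough sketch of such a loader might look like the following; the checkpoint name is only an example, and any model exposing `get_image_features`/`get_text_features` with a matching processor should work.

```python
from transformers import CLIPModel, CLIPProcessor


def my_clip_loader():
    """Illustrative loader returning a (model, processor) pair compatible with the CLIP interface."""
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    return model, processor


# The callable short-circuits the name-based branching:
# model, processor = _get_clip_model_and_processor(my_clip_loader)
```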
@@ -205,11 +248,17 @@ def _get_clip_model_and_processor(
def clip_score(
    source: Union[Tensor, List[Tensor], List[str], str],
    target: Union[Tensor, List[Tensor], List[str], str],
-     model_name_or_path: Literal[
-         "openai/clip-vit-base-patch16",
-         "openai/clip-vit-base-patch32",
-         "openai/clip-vit-large-patch14-336",
-         "openai/clip-vit-large-patch14",
+     model_name_or_path: Union[
+         Literal[
+             "openai/clip-vit-base-patch16",
+             "openai/clip-vit-base-patch32",
+             "openai/clip-vit-large-patch14-336",
+             "openai/clip-vit-large-patch14",
+             "jinaai/jina-clip-v2",
+             "zer0int/LongCLIP-L-Diffusers",
+             "zer0int/LongCLIP-GmP-ViT-L-14",
+         ],
+         Callable[[], tuple[_CLIPModel, _CLIPProcessor]],
    ] = "openai/clip-vit-large-patch14",
) -> Tensor:
    r"""Calculates `CLIP Score`_ which is a text-to-image similarity metric.
@@ -239,6 +288,11 @@ def clip_score(
    .. note:: Metric is not scriptable

+     .. note::
+         The default CLIP model and processor used in this implementation have a maximum sequence length of 77 for
+         text inputs. If you need to process longer captions, you can use the `zer0int/LongCLIP-L-Diffusers` model,
+         which has a maximum sequence length of 248.
+
    Args:
        source: Source input. This can be:
            - Images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors.
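The 77-token limit mentioned in the added note can be read off a checkpoint's Hugging Face config; a small check, using the default checkpoint as an example:

```python
from transformers import CLIPConfig

# Inspect the text tower's maximum sequence length for a given checkpoint
config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
print(config.text_config.max_position_embeddings)  # 77
```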
@@ -251,7 +305,14 @@ def clip_score(
            - `"openai/clip-vit-base-patch32"`
            - `"openai/clip-vit-large-patch14-336"`
            - `"openai/clip-vit-large-patch14"`
-
+             - `"jinaai/jina-clip-v2"`
+             - `"zer0int/LongCLIP-L-Diffusers"`
+             - `"zer0int/LongCLIP-GmP-ViT-L-14"`
+
+             Alternatively, a callable that returns a tuple of a CLIP-compatible model and processor instance can be
+             passed in. By compatible, we mean that the processor's `__call__` method should accept a list of strings
+             and a list of images, and that the model should have `get_image_features` and `get_text_features`
+             methods.

    Raises:
        ModuleNotFoundError:
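To close, a hedged end-to-end sketch of calling the updated metric with one of the newly supported checkpoints; it assumes a torchmetrics build that includes this change, and the model weights are downloaded on first use.

```python
import torch
from torchmetrics.functional.multimodal import clip_score

# Two random "images" and their captions; in practice these would be generated or real images
images = torch.randint(0, 255, (2, 3, 224, 224), dtype=torch.uint8)
captions = ["a photo of a cat", "a photo of a dog"]

score = clip_score(images, captions, model_name_or_path="zer0int/LongCLIP-L-Diffusers")
print(score)  # scalar tensor with the average CLIP score
```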