26
26
27
27
28
28
class DeepImageStructureAndTextureSimilarity (Metric ):
29
- """Calculates Deep Image Structure and Texture Similarity (DISTS) score."""
29
+ """Calculates Deep Image Structure and Texture Similarity (DISTS) score.
30
+
31
+ The metric is a full-reference image quality assessment (IQA) model that combines sensitivity to structural
32
+ distortions (e.g., artifacts due to noise, blur, or compression) with a tolerance of texture resampling
33
+ (exchanging the content of a texture region with a new sample of the same texture). The metric is based on
34
+ a convolutional neural network (CNN) that transforms the reference and distorted images to a new representation.
35
+ Within this representation, a set of measurements are developed that are sufficient to capture the appearance
36
+ of a variety of different visual distortions.
37
+
38
+ As input to ``forward`` and ``update`` the metric accepts the following input
39
+
40
+ - ``preds`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
41
+ - ``target`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
42
+
43
+ As output of `forward` and `compute` the metric returns the following output
44
+
45
+ - ``lpips`` (:class:`~torch.Tensor`): returns float scalar tensor with average LPIPS value over samples
46
+
47
+ Args:
48
+ reduction: specifies the reduction to apply to the output.
49
+ kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
50
+
51
+ Raises:
52
+ ValueError:
53
+ If `reduction` is not one of ["mean", "sum"]
54
+
55
+ Example:
56
+ >>> from torch import rand
57
+ >>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
58
+ >>> metric = DeepImageStructureAndTextureSimilarity()
59
+ >>> preds = rand(10, 3, 100, 100)
60
+ >>> target = rand(10, 3, 100, 100)
61
+ >>> metric(preds, target)
62
+ tensor(0.1882, grad_fn=<CloneBackward0>)
63
+
64
+ """
30
65
31
66
score : Tensor
32
67
total : Tensor
@@ -77,8 +112,8 @@ def plot(
77
112
78
113
>>> # Example plotting a single value
79
114
>>> import torch
80
- >>> from torchmetrics.image.lpip import DeepImageStructureAndTextureSimilarity
81
- >>> metric = DeepImageStructureAndTextureSimilarity(net_type='squeeze' )
115
+ >>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
116
+ >>> metric = DeepImageStructureAndTextureSimilarity()
82
117
>>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
83
118
>>> fig_, ax_ = metric.plot()
84
119
@@ -87,8 +122,8 @@ def plot(
87
122
88
123
>>> # Example plotting multiple values
89
124
>>> import torch
90
- >>> from torchmetrics.image.lpip import DeepImageStructureAndTextureSimilarity
91
- >>> metric = DeepImageStructureAndTextureSimilarity(net_type='squeeze' )
125
+ >>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
126
+ >>> metric = DeepImageStructureAndTextureSimilarity()
92
127
>>> values = [ ]
93
128
>>> for _ in range(3):
94
129
... values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
0 commit comments