Skip to content

Commit 5ff4740

Browse files
committed
improve docstring
1 parent 4ae61a4 commit 5ff4740

File tree

1 file changed

+40
-5
lines changed

1 file changed

+40
-5
lines changed

src/torchmetrics/image/dists.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,42 @@
2626

2727

2828
class DeepImageStructureAndTextureSimilarity(Metric):
29-
"""Calculates Deep Image Structure and Texture Similarity (DISTS) score."""
29+
"""Calculates Deep Image Structure and Texture Similarity (DISTS) score.
30+
31+
The metric is a full-reference image quality assessment (IQA) model that combines sensitivity to structural
32+
distortions (e.g., artifacts due to noise, blur, or compression) with a tolerance of texture resampling
33+
(exchanging the content of a texture region with a new sample of the same texture). The metric is based on
34+
a convolutional neural network (CNN) that transforms the reference and distorted images to a new representation.
35+
Within this representation, a set of measurements are developed that are sufficient to capture the appearance
36+
of a variety of different visual distortions.
37+
38+
As input to ``forward`` and ``update`` the metric accepts the following input
39+
40+
- ``preds`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
41+
- ``target`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)``
42+
43+
As output of `forward` and `compute` the metric returns the following output
44+
45+
- ``lpips`` (:class:`~torch.Tensor`): returns float scalar tensor with average LPIPS value over samples
46+
47+
Args:
48+
reduction: specifies the reduction to apply to the output.
49+
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
50+
51+
Raises:
52+
ValueError:
53+
If `reduction` is not one of ["mean", "sum"]
54+
55+
Example:
56+
>>> from torch import rand
57+
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
58+
>>> metric = DeepImageStructureAndTextureSimilarity()
59+
>>> preds = rand(10, 3, 100, 100)
60+
>>> target = rand(10, 3, 100, 100)
61+
>>> metric(preds, target)
62+
tensor(0.1882, grad_fn=<CloneBackward0>)
63+
64+
"""
3065

3166
score: Tensor
3267
total: Tensor
@@ -77,8 +112,8 @@ def plot(
77112
78113
>>> # Example plotting a single value
79114
>>> import torch
80-
>>> from torchmetrics.image.lpip import DeepImageStructureAndTextureSimilarity
81-
>>> metric = DeepImageStructureAndTextureSimilarity(net_type='squeeze')
115+
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
116+
>>> metric = DeepImageStructureAndTextureSimilarity()
82117
>>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
83118
>>> fig_, ax_ = metric.plot()
84119
@@ -87,8 +122,8 @@ def plot(
87122
88123
>>> # Example plotting multiple values
89124
>>> import torch
90-
>>> from torchmetrics.image.lpip import DeepImageStructureAndTextureSimilarity
91-
>>> metric = DeepImageStructureAndTextureSimilarity(net_type='squeeze')
125+
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
126+
>>> metric = DeepImageStructureAndTextureSimilarity()
92127
>>> values = [ ]
93128
>>> for _ in range(3):
94129
... values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))

0 commit comments

Comments
 (0)