Newmetric: DISTS #2993

Open
wants to merge 13 commits into
base: master
21 changes: 21 additions & 0 deletions docs/source/image/dists.rst
@@ -0,0 +1,21 @@
.. customcarditem::
:header: D
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

#################################################
Learned Perceptual Image Patch Similarity (LPIPS)
#################################################

Module Interface
________________

.. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.learned_perceptual_image_patch_similarity
@@ -1,21 +1,21 @@
.. customcarditem::
:header: Learned Perceptual Image Patch Similarity (LPIPS)
:header: Deep Image Structure And Texture Similarity (DISTS)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Image

.. include:: ../links.rst

#################################################
Learned Perceptual Image Patch Similarity (LPIPS)
#################################################
###################################################
Deep Image Structure And Texture Similarity (DISTS)
###################################################

Module Interface
________________

.. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
.. autoclass:: torchmetrics.image.dists.DeepImageStructureAndTextureSimilarity
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.learned_perceptual_image_patch_similarity
.. autofunction:: torchmetrics.functional.image.deep_image_structure_and_texture_similarity
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -180,5 +180,6 @@
.. _Log AUC: https://pubmed.ncbi.nlm.nih.gov/20735049/
.. _Negative Predictive Value: https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values
.. _ARNIQA: https://arxiv.org/abs/2310.14918
.. _Deep Image Structure and Texture Similarity: https://ieeexplore.ieee.org/abstract/document/9298952
.. _KonIQ-10k: https://database.mmsp-kn.de/koniq-10k-database.html
.. _KADID-10k: https://database.mmsp-kn.de/kadid-10k-database.html
1 change: 1 addition & 0 deletions requirements/image_test.txt
@@ -8,3 +8,4 @@ sewar >=0.4.4, <=0.4.6
numpy <2.3.0
torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master
lpips <=0.1.4
dists-pytorch
2 changes: 2 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
@@ -14,6 +14,7 @@
from torchmetrics.functional.image.arniqa import arniqa
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.d_s import spatial_distortion_index
from torchmetrics.functional.image.dists import deep_image_structure_and_texture_similarity
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
@@ -35,6 +36,7 @@

__all__ = [
"arniqa",
"deep_image_structure_and_texture_similarity",
"error_relative_global_dimensionless_synthesis",
"image_gradients",
"learned_perceptual_image_patch_similarity",
201 changes: 201 additions & 0 deletions src/torchmetrics/functional/image/dists.py
@@ -0,0 +1,201 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Below is a derivative work based on the original work:
# https://github.com/dingkeyan93/DISTS
# with the following license:
#
# MIT License
# Copyright (c) 2020 Keyan Ding
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import inspect
import os
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import conv2d
from torchvision import models
from typing_extensions import Literal


class L2pooling(nn.Module):
    """L2 pooling layer."""

    filter: Tensor

    def __init__(self, filter_size: int = 5, stride: int = 2, channels: int = 3) -> None:
        super().__init__()
        self.padding = (filter_size - 2) // 2
        self.stride = stride
        self.channels = channels
        # Hann window without its zero endpoints, turned into a normalized 2D filter.
        a = np.hanning(filter_size)[1:-1]
        g = torch.Tensor(a[:, None] * a[None, :])
        g = g / torch.sum(g)
        # One copy of the filter per channel, applied as a grouped (depthwise) convolution.
        self.register_buffer("filter", g[None, None, :, :].repeat(self.channels, 1, 1, 1))

    def forward(self, tensor: Tensor) -> Tensor:
        """Forward pass of the layer."""
        tensor = tensor**2
        out = conv2d(tensor, self.filter, stride=self.stride, padding=self.padding, groups=tensor.shape[1])
        # The small constant keeps the square root differentiable at zero.
        return (out + 1e-12).sqrt()
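
Review note: the filter that `L2pooling` registers can be reproduced without NumPy or PyTorch, which makes the default values easy to check by hand. The sketch below is only an illustration; `hann_window` and `l2_pooling_kernel` are hypothetical helper names that do not exist in the PR, and the window formula is the symmetric Hann definition that `np.hanning` documents.

```python
import math
from typing import List


def hann_window(size: int) -> List[float]:
    """Symmetric Hann window: w(n) = 0.5 - 0.5 * cos(2*pi*n / (size - 1))."""
    if size == 1:
        return [1.0]
    return [0.5 - 0.5 * math.cos(2.0 * math.pi * n / (size - 1)) for n in range(size)]


def l2_pooling_kernel(filter_size: int = 5) -> List[List[float]]:
    """Normalized 2D kernel built from the Hann window without its zero endpoints."""
    a = hann_window(filter_size)[1:-1]
    outer = [[u * v for v in a] for u in a]
    total = sum(sum(row) for row in outer)
    return [[value / total for value in row] for row in outer]


kernel = l2_pooling_kernel(5)
padding = (5 - 2) // 2  # same rule as in L2pooling.__init__
```

With the default `filter_size=5` the trimmed window is `[0.5, 1.0, 0.5]` up to rounding, so the kernel is 3x3 with a center weight of about 0.25, and the padding is 1.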


class DISTSNetwork(torch.nn.Module):
    """DISTS network."""

    alpha: Tensor
    beta: Tensor
    mean: Tensor
    std: Tensor

    def __init__(self, load_weights: bool = True) -> None:
        super().__init__()
        vgg_pretrained_features = models.vgg16(pretrained=True).features
        self.stage1 = torch.nn.Sequential()
        self.stage2 = torch.nn.Sequential()
        self.stage3 = torch.nn.Sequential()
        self.stage4 = torch.nn.Sequential()
        self.stage5 = torch.nn.Sequential()
        # Split VGG16 into five stages; the max-pooling layers (indices 4, 9, 16, 23) are replaced by L2 pooling.
        for x in range(4):
            self.stage1.add_module(str(x), vgg_pretrained_features[x])
        self.stage2.add_module(str(4), L2pooling(channels=64))
        for x in range(5, 9):
            self.stage2.add_module(str(x), vgg_pretrained_features[x])
        self.stage3.add_module(str(9), L2pooling(channels=128))
        for x in range(10, 16):
            self.stage3.add_module(str(x), vgg_pretrained_features[x])
        self.stage4.add_module(str(16), L2pooling(channels=256))
        for x in range(17, 23):
            self.stage4.add_module(str(x), vgg_pretrained_features[x])
        self.stage5.add_module(str(23), L2pooling(channels=512))
        for x in range(24, 30):
            self.stage5.add_module(str(x), vgg_pretrained_features[x])

        # Freeze the VGG layers. alpha and beta are registered afterwards, so they stay trainable.
        for param in self.parameters():
            param.requires_grad = False

        # ImageNet normalization statistics.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1))

        # Channel counts of the input image and of the five stage outputs.
        self.chns = [3, 64, 128, 256, 512, 512]
        self.register_parameter("alpha", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
        self.register_parameter("beta", nn.Parameter(torch.randn(1, sum(self.chns), 1, 1)))
        self.alpha.data.normal_(0.1, 0.01)
        self.beta.data.normal_(0.1, 0.01)
        if load_weights:
            path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), "..", "dists_models/weights.pt"))  # type: ignore
            # Load on CPU so the checkpoint does not depend on the device it was saved from.
            weights = torch.load(path, map_location="cpu")
            self.alpha.data = weights["alpha"]
            self.beta.data = weights["beta"]

    def forward_once(self, x: Tensor) -> List[Tensor]:
        """Return the input image and the activations of the five stages."""
        h = (x - self.mean) / self.std
        h = self.stage1(h)
        h_relu1_2 = h
        h = self.stage2(h)
        h_relu2_2 = h
        h = self.stage3(h)
        h_relu3_3 = h
        h = self.stage4(h)
        h_relu4_3 = h
        h = self.stage5(h)
        h_relu5_3 = h
        return [x, h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3]

    def forward(self, x: Tensor, y: Tensor, require_grad: bool = False) -> Tensor:
        """Computes DISTS score between two images."""
        if require_grad:
            feats0 = self.forward_once(x)
            feats1 = self.forward_once(y)
        else:
            with torch.inference_mode():
                feats0 = self.forward_once(x)
                feats1 = self.forward_once(y)
        dist1, dist2, c1, c2 = 0, 0, 1e-6, 1e-6
        # Normalize alpha and beta jointly so that all weights sum to one.
        w_sum = self.alpha.sum() + self.beta.sum()
        alpha = torch.split(self.alpha / w_sum, self.chns, dim=1)
        beta = torch.split(self.beta / w_sum, self.chns, dim=1)
        for k in range(len(self.chns)):
            # Mean-based term, weighted by alpha.
            x_mean = feats0[k].mean([2, 3], keepdim=True)
            y_mean = feats1[k].mean([2, 3], keepdim=True)
            s1 = (2 * x_mean * y_mean + c1) / (x_mean**2 + y_mean**2 + c1)
            dist1 = dist1 + (alpha[k] * s1).sum(1, keepdim=True)

            # Covariance-based term, weighted by beta.
            x_var = ((feats0[k] - x_mean) ** 2).mean([2, 3], keepdim=True)
            y_var = ((feats1[k] - y_mean) ** 2).mean([2, 3], keepdim=True)
            xy_cov = (feats0[k] * feats1[k]).mean([2, 3], keepdim=True) - x_mean * y_mean
            s2 = (2 * xy_cov + c2) / (x_var + y_var + c2)
            dist2 = dist2 + (beta[k] * s2).sum(1, keepdim=True)

        return 1 - (dist1 + dist2).squeeze()
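
Review note: the per-channel arithmetic in `forward` can be followed on plain Python lists. The sketch below mirrors the mean-based and covariance-based terms and the joint weight normalization. It is a toy illustration under assumptions: `channel_similarities` and `dists_distance` are hypothetical names, each "channel" is a flat list of activations, and the weights are made up rather than the trained values. Reading the mean term as texture and the covariance term as structure follows the usual DISTS description and is not stated in the code itself.

```python
from typing import Sequence, Tuple

C1 = 1e-6  # stabilizer for the mean-based term, same value as in forward()
C2 = 1e-6  # stabilizer for the covariance-based term, same value as in forward()


def channel_similarities(x: Sequence[float], y: Sequence[float]) -> Tuple[float, float]:
    """Return (texture, structure) similarity for one flattened feature channel."""
    n = len(x)
    x_mean = sum(x) / n
    y_mean = sum(y) / n
    texture = (2 * x_mean * y_mean + C1) / (x_mean**2 + y_mean**2 + C1)

    x_var = sum((v - x_mean) ** 2 for v in x) / n
    y_var = sum((v - y_mean) ** 2 for v in y) / n
    xy_cov = sum(a * b for a, b in zip(x, y)) / n - x_mean * y_mean
    structure = (2 * xy_cov + C2) / (x_var + y_var + C2)
    return texture, structure


def dists_distance(xs, ys, alpha, beta) -> float:
    """One minus the weighted similarity; alpha and beta are jointly normalized to sum to 1."""
    w_sum = sum(alpha) + sum(beta)
    total = 0.0
    for x, y, a, b in zip(xs, ys, alpha, beta):
        texture, structure = channel_similarities(x, y)
        total += (a / w_sum) * texture + (b / w_sum) * structure
    return 1.0 - total


channels = [[0.1, 0.4, 0.9, 0.3], [0.5, 0.5, 0.2, 0.8]]
other = [[0.8, 0.2, 0.1, 0.7], [0.1, 0.9, 0.6, 0.4]]
alpha = [0.1, 0.1]  # toy per-channel weights, not the trained values
beta = [0.1, 0.1]

identical = dists_distance(channels, channels, alpha, beta)
different = dists_distance(channels, other, alpha, beta)
```

Because both similarity terms equal one when the two inputs match and the normalized weights sum to one, `identical` is zero up to rounding, while `different` is strictly larger.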


def _dists_update(preds: Tensor, target: Tensor) -> Tensor:
    dists = DISTSNetwork().to(preds.device)
    return dists(preds, target, require_grad=preds.requires_grad)
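
Review note: `_dists_update` constructs a new `DISTSNetwork` on every call, which rebuilds VGG16 and reloads the weights each time. One possible direction, which the PR does not implement, is to memoize the network per device. The sketch below shows only the caching pattern with `functools.lru_cache`. `StandInNetwork` and `get_network` are hypothetical names, and the stand-in class takes the place of the real model so that the example runs without PyTorch.

```python
from functools import lru_cache


class StandInNetwork:
    """Hypothetical placeholder for an expensive-to-build model such as DISTSNetwork."""

    instances_built = 0

    def __init__(self, device: str) -> None:
        StandInNetwork.instances_built += 1
        self.device = device


@lru_cache(maxsize=None)
def get_network(device: str) -> StandInNetwork:
    """Build the network once per device string and reuse it on later calls."""
    return StandInNetwork(device)


first = get_network("cpu")
second = get_network("cpu")  # served from the cache, no second construction
```

Whether a cache like this is appropriate depends on how the functional interface is expected to manage memory, so it is offered as a discussion point and not as a required change.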


def _dists_compute(scores: Tensor, reduction: Optional[Literal["sum", "mean", "none"]]) -> Tensor:
    if reduction == "sum":
        return scores.sum()
    if reduction == "mean":
        return scores.mean()
    if reduction is None or reduction == "none":
        return scores
    raise ValueError(f"Argument `reduction` must be one of 'sum', 'mean', 'none' or None, but got {reduction}")


def deep_image_structure_and_texture_similarity(
    preds: Tensor, target: Tensor, reduction: Optional[Literal["sum", "mean", "none"]] = None
) -> Tensor:
    """Calculates `Deep Image Structure and Texture Similarity`_ (DISTS) score.

    Args:
        preds: Predicted image tensor of shape ``(N, 3, H, W)``.
        target: Target image tensor with the same shape as ``preds``.
        reduction: Reduction applied over the batch: ``'sum'``, ``'mean'``, or ``'none'``/``None`` to return
            one score per image pair.

    Returns:
        DISTS score between the two images, computed as one minus the weighted similarity,
        so lower values indicate more similar images.

    Example:
        >>> from torch import rand
        >>> preds = rand(5, 3, 256, 256)
        >>> target = rand(5, 3, 256, 256)
        >>> deep_image_structure_and_texture_similarity(preds, target)
        tensor([0.1285, 0.1344, 0.1356, 0.1277, 0.1276], grad_fn=<RsubBackward1>)
        >>> deep_image_structure_and_texture_similarity(preds, target, reduction='mean')
        tensor(0.1308, grad_fn=<MeanBackward0>)

    """
    scores = _dists_update(preds, target)
    return _dists_compute(scores, reduction)
Binary file not shown.
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.image.d_lambda import SpectralDistortionIndex
from torchmetrics.image.d_s import SpatialDistortionIndex
from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
from torchmetrics.image.mifid import MemorizationInformedFrechetInceptionDistance
from torchmetrics.image.psnr import PeakSignalNoiseRatio
@@ -32,6 +33,7 @@
)

__all__ = [
"DeepImageStructureAndTextureSimilarity",
"ErrorRelativeGlobalDimensionlessSynthesis",
"MemorizationInformedFrechetInceptionDistance",
"MultiScaleStructuralSimilarityIndexMeasure",