Skip to content

Commit

Permalink
Merge pull request #97 from sp-nitech/ssim
Browse files Browse the repository at this point in the history
Add ssim
  • Loading branch information
takenori-y authored Sep 11, 2024
2 parents fd7fd19 + 57ca27e commit 63c3b21
Show file tree
Hide file tree
Showing 6 changed files with 423 additions and 0 deletions.
82 changes: 82 additions & 0 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1974,6 +1974,88 @@ def spec(
)


def ssim(
    x,
    y,
    reduction="mean",
    *,
    alpha=1,
    beta=1,
    gamma=1,
    kernel_size=11,
    sigma=1.5,
    k1=0.01,
    k2=0.03,
    eps=1e-8,
    padding="same",
    dynamic_range=None,
):
    """Calculate SSIM.

    Functional counterpart of ``StructuralSimilarityIndex``; every argument is
    handed over unchanged to ``nn.StructuralSimilarityIndex._func``.

    Parameters
    ----------
    x : Tensor [shape=(..., N, D)]
        Input.

    y : Tensor [shape=(..., N, D)]
        Target.

    reduction : ['none', 'mean', 'sum']
        Reduction type.

    alpha : float > 0
        Relative importance of luminance component.

    beta : float > 0
        Relative importance of contrast component.

    gamma : float > 0
        Relative importance of structure component.

    kernel_size : int >= 1
        Kernel size of Gaussian filter.

    sigma : float > 0
        Standard deviation of Gaussian filter.

    k1 : float > 0
        A small constant.

    k2 : float > 0
        A small constant.

    eps : float >= 0
        A small value to prevent NaN.

    padding : ['valid', 'same']
        Padding type.

    dynamic_range : float > 0 or None
        Dynamic range of input. If None, input is automatically normalized.

    Returns
    -------
    out : Tensor [shape=(..., N, D) or scalar]
        SSIM or mean SSIM.

    """
    # Collect the options once and forward them by keyword, so the binding does
    # not depend on the positional order of the underlying implementation.
    options = {
        "reduction": reduction,
        "alpha": alpha,
        "beta": beta,
        "gamma": gamma,
        "kernel_size": kernel_size,
        "sigma": sigma,
        "k1": k1,
        "k2": k2,
        "eps": eps,
        "padding": padding,
        "dynamic_range": dynamic_range,
    }
    return nn.StructuralSimilarityIndex._func(x, y, **options)


def stft(
x,
*,
Expand Down
2 changes: 2 additions & 0 deletions diffsptk/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@
from .snr import SignalToNoiseRatio
from .snr import SignalToNoiseRatio as SNR
from .spec import Spectrum
from .ssim import StructuralSimilarityIndex
from .ssim import StructuralSimilarityIndex as SSIM
from .stft import ShortTimeFourierTransform
from .stft import ShortTimeFourierTransform as STFT
from .ulaw import MuLawCompression
Expand Down
250 changes: 250 additions & 0 deletions diffsptk/modules/ssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
# ------------------------------------------------------------------------ #
# Copyright 2022 SPTK Working Group #
# #
# Licensed under the Apache License, Version 2.0 (the "License"); #
# you may not use this file except in compliance with the License. #
# You may obtain a copy of the License at #
# #
# http://www.apache.org/licenses/LICENSE-2.0 #
# #
# Unless required by applicable law or agreed to in writing, software #
# distributed under the License is distributed on an "AS IS" BASIS, #
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. #
# See the License for the specific language governing permissions and #
# limitations under the License. #
# ------------------------------------------------------------------------ #

import torch
from torch import nn
import torch.nn.functional as F

from ..misc.utils import to


class StructuralSimilarityIndex(nn.Module):
    """Structural similarity index computation.

    Parameters
    ----------
    reduction : ['none', 'mean', 'sum']
        Reduction type.

    alpha : float > 0
        Relative importance of luminance component.

    beta : float > 0
        Relative importance of contrast component.

    gamma : float > 0
        Relative importance of structure component.

    kernel_size : int >= 1
        Kernel size of Gaussian filter. Must be odd.

    sigma : float > 0
        Standard deviation of Gaussian filter.

    k1 : float > 0
        A small constant. Must be less than 1.

    k2 : float > 0
        A small constant. Must be less than 1.

    eps : float >= 0
        A small value to prevent NaN.

    padding : ['valid', 'same']
        Padding type.

    dynamic_range : float > 0 or None
        Dynamic range of input. If None, input is automatically normalized.

    References
    ----------
    [1] Z. Wang et al., "Image quality assessment: From error visibility to structural
        similarity," *IEEE Transactions on Image Processing*, vol. 13, no. 4, pp.
        600-612, 2004.

    """

    def __init__(
        self,
        reduction="mean",
        *,
        alpha=1,
        beta=1,
        gamma=1,
        kernel_size=11,
        sigma=1.5,
        k1=0.01,
        k2=0.03,
        eps=1e-8,
        padding="same",
        dynamic_range=None,
    ):
        super().__init__()

        assert reduction in ["none", "mean", "sum"]
        # An odd kernel has a well-defined center, which "same" padding relies on.
        assert 1 <= kernel_size and kernel_size % 2 == 1
        assert 0 < sigma
        assert 0 < k1 < 1
        assert 0 < k2 < 1
        assert 0 <= eps

        self.reduction = reduction
        self.weights = (alpha, beta, gamma)
        self.ks = (k1, k2)
        self.eps = eps
        self.padding = padding
        self.dynamic_range = dynamic_range
        # Registered as a buffer so that it follows the module across
        # .to()/.double()/.cuda() calls.
        self.register_buffer("kernel", self._precompute(kernel_size, sigma))

    def forward(self, x, y):
        """Calculate SSIM.

        Parameters
        ----------
        x : Tensor [shape=(..., N, D)]
            Input.

        y : Tensor [shape=(..., N, D)]
            Target.

        Returns
        -------
        out : Tensor [shape=(..., N, D), (..., N-K+1, D-K+1), or scalar]
            SSIM map if reduction is 'none' (the smaller shape applies to 'valid'
            padding, where K is the kernel size), otherwise summed or mean SSIM.

        Examples
        --------
        >>> x = diffsptk.nrand(20, 20)
        >>> y = diffsptk.nrand(20, 20)
        >>> ssim = diffsptk.StructuralSimilarityIndex()
        >>> s = ssim(x, y)
        >>> s
        tensor(0.0588)

        """
        return self._forward(
            x,
            y,
            self.reduction,
            self.weights,
            self.ks,
            self.eps,
            self.padding,
            self.dynamic_range,
            self.kernel,
        )

    @staticmethod
    def _forward(x, y, reduction, weights, ks, eps, padding, dynamic_range, kernel):
        """Compute SSIM between x and y using a precomputed Gaussian kernel.

        Parameters
        ----------
        x, y : Tensor [shape=(..., N, D)]
            Input and target.

        reduction : ['none', 'mean', 'sum']
            Reduction type.

        weights : tuple
            Exponents (alpha, beta, gamma) of luminance, contrast, and structure.

        ks : tuple
            Constants (k1, k2).

        eps : float >= 0
            A small value added inside the square roots.

        padding : ['valid', 'same']
            Padding type.

        dynamic_range : float > 0 or None
            Dynamic range of input. If None, x and y are jointly min-max normalized.

        kernel : Tensor [shape=(1, 1, K, K)]
            Gaussian kernel.

        Returns
        -------
        out : Tensor
            SSIM map or its reduced value.

        Raises
        ------
        ValueError
            If padding or reduction is not supported.

        """
        org_shape = x.shape
        # reshape() instead of view(): leading dimensions of non-contiguous inputs
        # (e.g., permuted batches) cannot always be merged without a copy.
        x = x.reshape(-1, 1, x.shape[-2], x.shape[-1])
        y = y.reshape(-1, 1, y.shape[-2], y.shape[-1])

        # Normalize x and y to [0, 1].
        if dynamic_range is None:
            x_max = torch.amax(x, dim=(-2, -1), keepdim=True)
            x_min = torch.amin(x, dim=(-2, -1), keepdim=True)
            y_max = torch.amax(y, dim=(-2, -1), keepdim=True)
            y_min = torch.amin(y, dim=(-2, -1), keepdim=True)
            xy_max = torch.maximum(x_max, y_max)
            xy_min = torch.minimum(x_min, y_min)
            d = xy_max - xy_min
            # If both inputs are constant and equal, d is zero and 0 / 0 would
            # poison the whole result with NaN. The numerators are zero in that
            # case, so any non-zero divisor yields the intended all-zero signals.
            d = torch.where(d > 0, d, torch.ones_like(d))
            x = (x - xy_min) / d
            y = (y - xy_min) / d
            dynamic_range = 1

        # Pad x and y.
        if padding == "valid":
            pass
        elif padding == "same":
            pad_size = kernel.shape[-1] // 2
            pad = (pad_size, pad_size, pad_size, pad_size)
            x = F.pad(x, pad, mode="reflect")
            y = F.pad(y, pad, mode="reflect")
        else:
            raise ValueError(f"padding {padding} is not supported.")

        # Set constants.
        K1, K2 = ks
        L = dynamic_range
        C1 = (K1 * L) ** 2
        C2 = (K2 * L) ** 2
        C3 = 0.5 * C2

        # Calculate luminance.
        mu_x = F.conv2d(x, kernel, padding=0)
        mu_y = F.conv2d(y, kernel, padding=0)
        mu2_x = mu_x**2
        mu2_y = mu_y**2
        luminance = (2 * mu_x * mu_y + C1) / (mu2_x + mu2_y + C1)

        # Calculate contrast.
        sigma2_x = F.conv2d(x**2, kernel, padding=0) - mu2_x
        sigma2_y = F.conv2d(y**2, kernel, padding=0) - mu2_y
        # E[x^2] - E[x]^2 is non-negative in exact arithmetic, but rounding can
        # push it below -eps, which would make sqrt() return NaN.
        sigma_x = torch.sqrt(torch.clamp(sigma2_x, min=0) + eps)
        sigma_y = torch.sqrt(torch.clamp(sigma2_y, min=0) + eps)
        contrast = (2 * sigma_x * sigma_y + C2) / (sigma2_x + sigma2_y + C2)

        # Calculate structure.
        mu_xy = mu_x * mu_y
        sigma2_xy = F.conv2d(x * y, kernel, padding=0) - mu_xy
        structure = (sigma2_xy + C3) / (sigma_x * sigma_y + C3)

        # Calculate SSIM.
        alpha, beta, gamma = weights
        ssim = (luminance**alpha) * (contrast**beta) * (structure**gamma)
        # Restore the leading dimensions; the channel axis added above is dropped.
        ssim = ssim.reshape(*org_shape[:-2], *ssim.shape[-2:])

        if reduction == "none":
            pass
        elif reduction == "sum":
            ssim = ssim.sum()
        elif reduction == "mean":
            ssim = ssim.mean()
        else:
            raise ValueError(f"reduction {reduction} is not supported.")
        return ssim

    @staticmethod
    def _func(
        x,
        y,
        reduction,
        alpha,
        beta,
        gamma,
        kernel_size,
        sigma,
        k1,
        k2,
        eps,
        padding,
        dynamic_range,
    ):
        """Stateless entry point: build the kernel for x and compute SSIM."""
        # The kernel is generated on the fly with the dtype and device of x.
        kernel = StructuralSimilarityIndex._precompute(
            kernel_size, sigma, dtype=x.dtype, device=x.device
        )
        return StructuralSimilarityIndex._forward(
            x,
            y,
            reduction,
            (alpha, beta, gamma),
            (k1, k2),
            eps,
            padding,
            dynamic_range,
            kernel,
        )

    @staticmethod
    def _precompute(kernel_size, sigma, dtype=None, device=None):
        """Return a unit-sum 2D Gaussian kernel of shape (1, 1, K, K)."""
        # Generate 2D Gaussian kernel.
        center = kernel_size // 2
        # Computed in double precision; the cast to the target dtype happens last.
        x = torch.arange(kernel_size, dtype=torch.double, device=device) - center
        xx = x**2
        G = torch.exp(-0.5 * (xx.unsqueeze(0) + xx.unsqueeze(1)) / sigma**2)
        G /= G.sum()  # Normalized to unit sum.
        G = G.view(1, 1, kernel_size, kernel_size)
        # NOTE(review): `to` presumably falls back to the default dtype when dtype
        # is None -- confirm against misc.utils.
        return to(G, dtype=dtype)
15 changes: 15 additions & 0 deletions docs/modules/ssim.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
.. _ssim:

ssim
====

.. autoclass:: diffsptk.SSIM

.. autoclass:: diffsptk.StructuralSimilarityIndex
:members:

.. autofunction:: diffsptk.functional.ssim

.. seealso::

:ref:`rmse`
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dev = [
"pytest",
"pytest-cov",
"ruff",
"scikit-image",
"sphinx",
"twine",
]
Expand Down
Loading

0 comments on commit 63c3b21

Please sign in to comment.