Commit 34fca16

newmetric: Jensen-Shannon Divergence (#2992)
* implementation
* tests
* plot testing
* docs page
* changelog
* fixes
1 parent d6a1ad2 commit 34fca16

File tree

8 files changed (+452, -0 lines changed)


CHANGELOG.md (+3)

@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added `JensenShannonDivergence` metric to regression package ([#2992](https://github.com/Lightning-AI/torchmetrics/pull/2992))
+
+
 - Added `ClusterAccuracy` metric to cluster package ([#2777](https://github.com/Lightning-AI/torchmetrics/pull/2777))

New documentation page (+21)

.. customcarditem::
    :header: Jensen-Shannon Divergence
    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
    :tags: Regression

.. include:: ../links.rst

#########################
Jensen-Shannon Divergence
#########################

Module Interface
________________

.. autoclass:: torchmetrics.regression.JensenShannonDivergence
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.regression.jensen_shannon_divergence

src/torchmetrics/functional/regression/__init__.py (+2)

@@ -15,6 +15,7 @@
 from torchmetrics.functional.regression.cosine_similarity import cosine_similarity
 from torchmetrics.functional.regression.csi import critical_success_index
 from torchmetrics.functional.regression.explained_variance import explained_variance
+from torchmetrics.functional.regression.js_divergence import jensen_shannon_divergence
 from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef
 from torchmetrics.functional.regression.kl_divergence import kl_divergence
 from torchmetrics.functional.regression.log_cosh import log_cosh_error
@@ -37,6 +38,7 @@
     "cosine_similarity",
     "critical_success_index",
     "explained_variance",
+    "jensen_shannon_divergence",
     "kendall_rank_corrcoef",
     "kl_divergence",
     "log_cosh_error",
src/torchmetrics/functional/regression/js_divergence.py (new file, +102)

# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.kl_divergence import kl_divergence
from torchmetrics.utilities.checks import _check_same_shape


def _jsd_update(p: Tensor, q: Tensor, log_prob: bool) -> tuple[Tensor, int]:
    """Update and return Jensen-Shannon divergence scores for each observation and the total number of observations.

    Args:
        p: data distribution with shape ``[N, d]``
        q: prior or approximate distribution with shape ``[N, d]``
        log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities,
            will normalize to make sure the distributions sum to 1

    """
    _check_same_shape(p, q)
    if p.ndim != 2 or q.ndim != 2:
        raise ValueError(f"Expected both p and q distribution to be 2D but got {p.ndim} and {q.ndim} respectively")

    total = p.shape[0]
    if log_prob:
        mean = torch.logsumexp(torch.stack([p, q]), dim=0) - torch.log(torch.tensor(2.0))
        measures = 0.5 * kl_divergence(p, mean, log_prob=log_prob, reduction=None) + 0.5 * kl_divergence(
            q, mean, log_prob=log_prob, reduction=None
        )
    else:
        p = p / p.sum(axis=-1, keepdim=True)  # type: ignore[call-overload]
        q = q / q.sum(axis=-1, keepdim=True)  # type: ignore[call-overload]
        mean = (p + q) / 2
        measures = 0.5 * kl_divergence(p, mean, log_prob=log_prob, reduction=None) + 0.5 * kl_divergence(
            q, mean, log_prob=log_prob, reduction=None
        )
    return measures, total


def _jsd_compute(
    measures: Tensor, total: Union[int, Tensor], reduction: Literal["mean", "sum", "none", None] = "mean"
) -> Tensor:
    """Compute and reduce the Jensen-Shannon divergence based on the type of reduction."""
    if reduction == "sum":
        return measures.sum()
    if reduction == "mean":
        return measures.sum() / total
    if reduction is None or reduction == "none":
        return measures
    return measures / total


def jensen_shannon_divergence(
    p: Tensor, q: Tensor, log_prob: bool = False, reduction: Literal["mean", "sum", "none", None] = "mean"
) -> Tensor:
    r"""Compute `Jensen-Shannon divergence`_.

    .. math::
        D_{JS}(P||Q) = \frac{1}{2} D_{KL}(P||M) + \frac{1}{2} D_{KL}(Q||M)

    Where :math:`P` and :math:`Q` are probability distributions where :math:`P` usually represents a distribution
    over data and :math:`Q` is often a prior or approximation of :math:`P`. :math:`D_{KL}` is the `KL divergence`_ and
    :math:`M` is the average of the two distributions. It should be noted that the Jensen-Shannon divergence is a
    symmetric metric, i.e. :math:`D_{JS}(P||Q) = D_{JS}(Q||P)`.

    Args:
        p: data distribution with shape ``[N, d]``
        q: prior or approximate distribution with shape ``[N, d]``
        log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities,
            will normalize to make sure the distributions sum to 1
        reduction:
            Determines how to reduce over the ``N``/batch dimension:

            - ``'mean'`` [default]: Averages score across samples
            - ``'sum'``: Sum score across samples
            - ``'none'`` or ``None``: Returns score per sample

    Example:
        >>> from torch import tensor
        >>> p = tensor([[0.36, 0.48, 0.16]])
        >>> q = tensor([[1/3, 1/3, 1/3]])
        >>> jensen_shannon_divergence(p, q)
        tensor(0.0225)

    """
    measures, total = _jsd_update(p, q, log_prob)
    return _jsd_compute(measures, total, reduction)
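The doctest value above can be checked by hand from the formula in the docstring: build the mixture M = (P + Q) / 2 and average the two KL terms. A small verification sketch (not part of the diff), using only torch:

import torch

p = torch.tensor([[0.36, 0.48, 0.16]])
q = torch.tensor([[1 / 3, 1 / 3, 1 / 3]])

# Mixture distribution M = (P + Q) / 2.
m = (p + q) / 2

# KL(P||M) and KL(Q||M) with natural logarithms, summed over the distribution axis.
kl_pm = (p * (p / m).log()).sum(dim=-1)
kl_qm = (q * (q / m).log()).sum(dim=-1)

js = 0.5 * kl_pm + 0.5 * kl_qm
print(js)  # tensor([0.0225]), matching the docstring example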

src/torchmetrics/regression/__init__.py (+2)

@@ -15,6 +15,7 @@
 from torchmetrics.regression.cosine_similarity import CosineSimilarity
 from torchmetrics.regression.csi import CriticalSuccessIndex
 from torchmetrics.regression.explained_variance import ExplainedVariance
+from torchmetrics.regression.js_divergence import JensenShannonDivergence
 from torchmetrics.regression.kendall import KendallRankCorrCoef
 from torchmetrics.regression.kl_divergence import KLDivergence
 from torchmetrics.regression.log_cosh import LogCoshError
@@ -37,6 +38,7 @@
     "CosineSimilarity",
     "CriticalSuccessIndex",
     "ExplainedVariance",
+    "JensenShannonDivergence",
     "KLDivergence",
     "KendallRankCorrCoef",
     "LogCoshError",
src/torchmetrics/regression/js_divergence.py (new file, +172)

# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import log
from typing import Any, List, Optional, Sequence, Union, cast

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.regression.js_divergence import _jsd_compute, _jsd_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["JensenShannonDivergence.plot"]


class JensenShannonDivergence(Metric):
    r"""Compute the `Jensen-Shannon divergence`_.

    .. math::
        D_{JS}(P||Q) = \frac{1}{2} D_{KL}(P||M) + \frac{1}{2} D_{KL}(Q||M)

    Where :math:`P` and :math:`Q` are probability distributions where :math:`P` usually represents a distribution
    over data and :math:`Q` is often a prior or approximation of :math:`P`. :math:`D_{KL}` is the `KL divergence`_ and
    :math:`M` is the average of the two distributions. It should be noted that the Jensen-Shannon divergence is a
    symmetric metric, i.e. :math:`D_{JS}(P||Q) = D_{JS}(Q||P)`.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``p`` (:class:`~torch.Tensor`): a data distribution with shape ``(N, d)``
    - ``q`` (:class:`~torch.Tensor`): prior or approximate distribution with shape ``(N, d)``

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``js_divergence`` (:class:`~torch.Tensor`): A tensor with the Jensen-Shannon divergence

    Args:
        log_prob: bool indicating if input is log-probabilities or probabilities. If given as probabilities,
            will normalize to make sure the distributions sum to 1.
        reduction:
            Determines how to reduce over the ``N``/batch dimension:

            - ``'mean'`` [default]: Averages score across samples
            - ``'sum'``: Sum score across samples
            - ``'none'`` or ``None``: Returns score per sample

        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        TypeError:
            If ``log_prob`` is not a ``bool``.
        ValueError:
            If ``reduction`` is not one of ``'mean'``, ``'sum'``, ``'none'`` or ``None``.

    .. attention::
        Half precision is only supported on GPU for this metric.

    Example:
        >>> from torch import tensor
        >>> from torchmetrics.regression import JensenShannonDivergence
        >>> p = tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]])
        >>> q = tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])
        >>> js_div = JensenShannonDivergence()
        >>> js_div(p, q)
        tensor(0.0259)

    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = log(2)

    measures: Union[Tensor, List[Tensor]]
    total: Tensor

    def __init__(
        self,
        log_prob: bool = False,
        reduction: Literal["mean", "sum", "none", None] = "mean",
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if not isinstance(log_prob, bool):
            raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}")
        self.log_prob = log_prob

        allowed_reduction = ["mean", "sum", "none", None]
        if reduction not in allowed_reduction:
            raise ValueError(f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}")
        self.reduction = reduction

        if self.reduction in ["mean", "sum"]:
            self.add_state("measures", torch.tensor(0.0), dist_reduce_fx="sum")
        else:
            self.add_state("measures", [], dist_reduce_fx="cat")
        self.add_state("total", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, p: Tensor, q: Tensor) -> None:
        """Update the metric state."""
        measures, total = _jsd_update(p, q, self.log_prob)
        if self.reduction is None or self.reduction == "none":
            cast(List[Tensor], self.measures).append(measures)
        else:
            self.measures = cast(Tensor, self.measures) + measures.sum()
        self.total += total

    def compute(self) -> Tensor:
        """Compute metric."""
        measures: Tensor = (
            dim_zero_cat(cast(List[Tensor], self.measures))
            if self.reduction in ["none", None]
            else cast(Tensor, self.measures)
        )
        return _jsd_compute(measures, self.total, self.reduction)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: A matplotlib axis object. If provided, will add the plot to that axis.

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting a single value
            >>> from torchmetrics.regression import JensenShannonDivergence
            >>> metric = JensenShannonDivergence()
            >>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import randn
            >>> # Example plotting multiple values
            >>> from torchmetrics.regression import JensenShannonDivergence
            >>> metric = JensenShannonDivergence()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
            >>> fig, ax = metric.plot(values)

        """
        return self._plot(val, ax)
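A short usage sketch (not part of the diff) for the module interface, showing state accumulation across batches and the symmetry noted in the docstring; the random tensors are purely illustrative:

import torch
from torchmetrics.regression import JensenShannonDivergence

torch.manual_seed(0)
metric = JensenShannonDivergence(reduction="mean")

# Accumulate over several batches; compute() reduces over all samples seen so far.
for _ in range(4):
    p = torch.randn(8, 5).softmax(dim=-1)
    q = torch.randn(8, 5).softmax(dim=-1)
    metric.update(p, q)
print(metric.compute())  # scalar mean divergence over 4 * 8 samples

# Symmetry: D_JS(P||Q) equals D_JS(Q||P) for the same batch.
per_sample = JensenShannonDivergence(reduction=None)
print(torch.allclose(per_sample(p, q), per_sample(q, p)))  # True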
