
Adding Jensen-Shannon (JS) Divergence Metric #2947

Open · alifa98 opened this issue Feb 7, 2025 · 1 comment · May be fixed by #2992

Labels: enhancement (New feature or request), New metric

alifa98 commented Feb 7, 2025

🚀 Feature

Given two probability distributions $P$ and $Q$, the JS divergence is defined as:

$$JS(P \mid \mid Q) = \frac{1}{2} KL(P \mid \mid M) + \frac{1}{2} KL(Q \mid \mid M)$$

where:

$KL(P \mid \mid Q)$ is the Kullback-Leibler divergence:

$$ KL(P \mid \mid Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)} $$

$M$ is the mixture (average) distribution, defined as

$$M = \frac{1}{2} (P + Q)$$

Thus, the JS divergence computes the KL divergence between each distribution and the average of both, making it symmetric and bounded in the range:

$$0 \leq JS(P \mid \mid Q) \leq \log 2$$
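
For concreteness, here is a minimal sketch of the definition above in plain PyTorch. This is not the proposed TorchMetrics implementation; the function name `js_divergence` and the `eps` smoothing term are only illustrative:

```python
import torch

def js_divergence(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """JS divergence between two discrete probability vectors (natural log)."""
    # Normalize so the inputs are valid probability distributions.
    p = p / p.sum()
    q = q / q.sum()
    # Mixture distribution M = (P + Q) / 2.
    m = 0.5 * (p + q)
    # KL(P || M) and KL(Q || M), with eps to avoid log(0).
    kl_pm = torch.sum(p * torch.log((p + eps) / (m + eps)))
    kl_qm = torch.sum(q * torch.log((q + eps) / (m + eps)))
    return 0.5 * kl_pm + 0.5 * kl_qm

p = torch.tensor([0.1, 0.2, 0.7])
q = torch.tensor([0.3, 0.3, 0.4])
print(js_divergence(p, q))  # always in [0, log 2] ≈ [0, 0.6931]
```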

Motivation

TorchMetrics currently includes Kullback-Leibler (KL) divergence, which is widely used for measuring differences between probability distributions. However, KL divergence is asymmetric and unbounded, which makes it less suitable for some applications.

Jensen-Shannon (JS) divergence provides a symmetric and bounded alternative by averaging the KL divergences of the two distributions from their mixture. This makes it useful for a range of tasks, including generative modeling, NLP, and probabilistic machine learning.

Adding JS divergence to TorchMetrics would align with the existing divergence metrics.

Pitch

I propose adding JS divergence as a new metric in TorchMetrics, implemented analogously to the existing KL divergence metric. My implementation is already available in my forked repository, and I'd be happy to refine it based on community feedback.
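
To make the pitch concrete, here is a rough sketch of what such a metric could look like on top of the `torchmetrics.Metric` base class. This is not the implementation in my fork or the linked PR; the class name, argument names, and mean-over-samples reduction are only assumptions:

```python
import torch
from torchmetrics import Metric

class JensenShannonDivergence(Metric):
    """Sketch: mean JS divergence over batches of distributions of shape (N, num_classes)."""

    def __init__(self, eps: float = 1e-12, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        # Running sum of per-sample divergences and the number of samples seen.
        self.add_state("measure", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, p: torch.Tensor, q: torch.Tensor) -> None:
        # Per-sample JS divergence over the last dimension.
        m = 0.5 * (p + q)
        kl_pm = torch.sum(p * torch.log((p + self.eps) / (m + self.eps)), dim=-1)
        kl_qm = torch.sum(q * torch.log((q + self.eps) / (m + self.eps)), dim=-1)
        self.measure += torch.sum(0.5 * kl_pm + 0.5 * kl_qm)
        self.total += p.shape[0]

    def compute(self) -> torch.Tensor:
        return self.measure / self.total
```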

Alternatives

  • Manual computation: users can compute JS divergence by hand from two KL terms, but a built-in metric removes that boilerplate and plugs into TorchMetrics' usual update/compute accumulation.
  • Third-party libraries: some users rely on scipy.spatial.distance.jensenshannon, which returns the Jensen-Shannon distance, i.e. the square root of the JS divergence (see the sketch below).
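
To illustrate the SciPy point (assuming SciPy is installed), squaring the returned distance recovers the divergence defined above:

```python
import numpy as np
from scipy.spatial.distance import jensenshannon

p = np.array([0.1, 0.2, 0.7])
q = np.array([0.3, 0.3, 0.4])

# SciPy returns the Jensen-Shannon *distance* (a proper metric),
# which is the square root of the JS divergence; natural log by default.
js_distance = jensenshannon(p, q)
js_div = js_distance ** 2
print(js_distance, js_div)
```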

Additional context

N/A

@alifa98 added the enhancement (New feature or request) label Feb 7, 2025

github-actions bot commented Feb 7, 2025

Hi! Thanks for your contribution! Great first issue!

@SkafteNicki added this to the future milestone Feb 24, 2025
@SkafteNicki changed the milestone from future to v1.7.0 Mar 6, 2025
@SkafteNicki self-assigned this Mar 6, 2025
@SkafteNicki linked a pull request (#2992) Mar 6, 2025 that will close this issue