
A differentiable PyTorch VMAF implementation. #2961

Open
NilanEkanayake opened this issue Feb 18, 2025 · 1 comment · May be fixed by #2991
Labels: enhancement (New feature or request), New metric
Milestone: future

Comments

NilanEkanayake commented Feb 18, 2025

🚀 Feature

See: https://github.com/alvitrioliks/VMAF-torch
I've tested it, and it seemingly works as advertised! It doesn't require the original VMAF binaries.
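
For a quick check, here's a minimal scoring sketch. The constructor arguments and the call signature mirror the loss module further down; the input layout, (num_frames, 1, H, W) luma tensors in [0, 255], is my assumption:

import torch
from vmaf_torch import VMAF

vmaf = VMAF(temporal_pooling=True, enable_motion=True, clip_score=True, NEG=False).to(torch.float32)

# Assumed input layout: (num_frames, 1, H, W) luma tensors in [0, 255].
ref = torch.rand(8, 1, 256, 256) * 255
dist = torch.rand(8, 1, 256, 256) * 255

with torch.no_grad():
    score = vmaf(ref, dist)  # VMAF score on the usual 0-100 scale
print(score)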

Motivation

VMAF (see https://github.com/Netflix/vmaf) is a reference-based quality assessment metric for videos (and images). So far, however, it has been tied to the original C implementation and wrappers around it, so it could not be used directly as a loss function.
The implementation linked above fixes this.

Pitch

It should be easy enough to add. I'm busy at the moment with other projects, but if nobody else is interested in adding it and nothing else is in the way, I can try to send a PR for it. I'm not sure whether the code's BSD 3-Clause license would cause issues, but I can't see why it would.

Additional context

VMAF is sensitive to numerical precision (bf16 vs fp32), so watch out for that.
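
If training runs under mixed precision, one option (a sketch, not part of the original code; vmaf_loss_fn, dist, and ref are hypothetical names for an instance of the VMAFLoss module defined below and the two video tensors) is to disable autocast around the metric so it stays in fp32:

import torch

# Hypothetical guard: keep the VMAF computation in fp32 even when the
# surrounding training step runs under bf16/fp16 autocast.
with torch.autocast(device_type="cuda", enabled=False):
    loss = vmaf_loss_fn(dist.float(), ref.float())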
If someone is interested in adding it to their project, here's the code I use (I also combine it with an MSE loss rather than using VMAF on its own):

import torch
from torch import nn
from vmaf_torch import VMAF
from einops import rearrange

class VMAFLoss(nn.Module):
    def __init__(self):
        super().__init__()

        # Keep the metric in fp32 (VMAF is sensitive to lower precision).
        self.vmaf = VMAF(temporal_pooling=True, enable_motion=True, clip_score=True, NEG=False).to(torch.float32)

        # Freeze the metric's parameters; we only backprop through it.
        for param in self.vmaf.parameters():
            param.requires_grad = False

    def get_luma(self, video):
        # video: (B, T, C, H, W) RGB in [0, 1]; split the channels.
        r = video[..., 0, :, :]
        g = video[..., 1, :, :]
        b = video[..., 2, :, :]

        # BT.601 luma coefficients.
        y = 0.299 * r + 0.587 * g + 0.114 * b

        # (B, T, H, W) -> (B, 1, T, H, W), scaled to 0-255 for VMAF.
        return y.unsqueeze(1) * 255

    def forward(self, dist, ref):
        orig_dtype = dist.dtype

        dist = (dist.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
        ref = (ref.clamp(-1, 1).to(torch.float32) + 1) / 2

        dist_luma = self.get_luma(dist)
        ref_luma = self.get_luma(ref)

        # VMAF scores are on a 0-100 scale; map to a loss in [0, 1].
        # Frames of all clips are flattened into one batch here;
        # should it iterate over batched videos instead?
        vmaf_loss = 1 - self.vmaf(rearrange(ref_luma, "b c t h w -> (b t) c h w"),
                                  rearrange(dist_luma, "b c t h w -> (b t) c h w")) / 100

        return vmaf_loss.to(orig_dtype)
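
For illustration, a hypothetical training-step snippet combining this with an MSE term, as mentioned above (the weight and tensor shapes are made-up placeholders, not values from my actual setup):

import torch
import torch.nn.functional as F

vmaf_loss_fn = VMAFLoss()

# Clips are (batch, frames, channels, H, W) in [-1, 1],
# matching the layout VMAFLoss.get_luma expects.
ref = torch.rand(2, 4, 3, 256, 256) * 2 - 1
dist = torch.rand(2, 4, 3, 256, 256) * 2 - 1

lambda_vmaf = 0.1  # made-up weighting, tune per task
loss = F.mse_loss(dist, ref) + lambda_vmaf * vmaf_loss_fn(dist, ref)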
NilanEkanayake added the enhancement label Feb 18, 2025

Hi! Thanks for your contribution! Great first issue!

SkafteNicki added this to the future milestone Feb 25, 2025
SkafteNicki linked a pull request (#2991) Mar 5, 2025 that will close this issue