
Biquad functions seemingly not compatible with autocast with bfloat16 #3880

Open
pokepress opened this issue Feb 17, 2025 · 0 comments
🐛 Describe the bug

I attempted to use highpass_biquad to calculate a loss while inside of an autocast block:

self.floatFormat = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
...
with torch.autocast(device_type="cuda", dtype=self.floatFormat):
...
                losses = self._get_losses(hr_reprs, pr_reprs)

def _get_losses(self, hr, pr):
...
                pr_highpass = torchaudio.functional.highpass_biquad(pr_time, self.args.experiment.hr_sr, self.args.experiment.hr_sr/10, 1.5)
                hr_highpass = torchaudio.functional.highpass_biquad(hr_time, self.args.experiment.hr_sr, self.args.experiment.hr_sr/10, 1.5)
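
For context, a stripped-down, standalone sketch of the situation looks roughly like this (the sample rate and tensor are arbitrary placeholders; the bfloat16 tensor just stands in for the model output that autocast produces upstream):

import torch
import torchaudio

sr = 44100  # placeholder sample rate
# Stand-in for the bfloat16 model output produced inside the autocast region.
pr_time = torch.randn(1, sr, device="cuda", dtype=torch.bfloat16)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # highpass_biquad -> biquad -> lfilter, and lfilter only accepts float32/float64
    pr_highpass = torchaudio.functional.highpass_biquad(pr_time, sr, sr / 10, 1.5)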

However, it doesn't seem like it works if bfloat16 is in use:

  File "d:\...\.venv\lib\site-packages\torchaudio\functional\filtering.py", line 922, in highpass_biquad       
    return biquad(waveform, b0, b1, b2, a0, a1, a2)
  File "d:\...\.venv\lib\site-packages\torchaudio\functional\filtering.py", line 327, in biquad
    output_waveform = lfilter(
  File "d:\...\.venv\lib\site-packages\torchaudio\functional\filtering.py", line 1059, in lfilter
    output = _lfilter(waveform, a_coeffs, b_coeffs)
  File "d:\...\.venv\lib\site-packages\torch\_ops.py", line 1123, in __call__
    return self._op(*args, **(kwargs or {}))
RuntimeError: Expected (in.dtype() == torch::kFloat32 || in.dtype() == torch::kFloat64) && (a_flipped.dtype() == torch::kFloat32 || a_flipped.dtype() == torch::kFloat64) && (padded_out.dtype() == torch::kFloat32 || padded_out.dtype() == torch::kFloat64) to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

Basically, the native op is complaining that the input wasn't in one of the dtypes it expects (32- or 64-bit float). My understanding is that autocast is supposed to handle this itself, but for some reason it isn't. I did a little debugging, and the Python code in filtering.py appears to treat the data as a 32-bit float in this case, so I can only assume the tensor isn't being passed to the native code in the expected dtype(?). I got around this by using float32 whenever this particular loss is active.
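
A more local workaround might be to cast just the filter inputs to float32 and step out of autocast around those calls, instead of switching the whole run to float32. A rough sketch of what that could look like inside _get_losses (untested here, just illustrative):

sr = self.args.experiment.hr_sr
# Temporarily disable autocast and feed float32 tensors to the biquad filters;
# the resulting tensors can then be used for the loss as before.
with torch.autocast(device_type="cuda", enabled=False):
    pr_highpass = torchaudio.functional.highpass_biquad(pr_time.float(), sr, sr / 10, 1.5)
    hr_highpass = torchaudio.functional.highpass_biquad(hr_time.float(), sr, sr / 10, 1.5)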

This is on 2.6.0/0.21.0, but I had the same problem in 2.4.1. Windows 10.

Versions

I couldn't get the last line of the environment-collection snippet to run, sorry.
