Precision argument for matmul is not respected on Ampere GPUs like the RTX 3090 #14022
When I tried to run the example, I could not reproduce it on the RTX 3090, but I could reproduce it on an A100:
This is using the latest development branch.
Thanks for giving this a try! The interesting thing is that the RTX 3090 reports zero difference between the float32 and tensorfloat32 calculations on the GPU, while the difference from the CPU result is at the level of machine precision. So it seems the RTX 3090 in your case uses float32 precision no matter what. The A100, on the other hand, seems to use tensorfloat32 irrespective of the precision settings.
This should be fixed in the upcoming jaxlib release.
You might get somewhere with setting `XLA_FLAGS=--xla_gpu_enable_xla_runtime_executable=false` until then?
Thanks! It seems that `XLA_FLAGS=--xla_gpu_enable_xla_runtime_executable=false` does not resolve the problem on A100 GPUs with JAX 0.4.2. Was the fix made in the new runtime code?
I can confirm this is fixed in our latest internal build:
It's in the latest development branch, so a nightly build or the next release (0.4.3) should have it. |
I confirmed that this issue is fixed now. We can close this bug. |
Description
As mentioned in issue #4873, matrix multiplication defaults to using tensorfloat32 on Ampere GPUs like the RTX 3090. However, it would be nice to be able to specify whether tensorfloat32 or float32 is used via the `precision` argument of jax.numpy's matrix operations. It seems this is currently not possible, and tensorfloat32 will be used irrespective of the value of the `precision` argument. So at the moment, it appears the only way to disable tensorfloat32 usage is globally, by setting the `NVIDIA_TF32_OVERRIDE=0` environment variable. See below for a minimal reproduction example:
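The original snippet is not preserved in this page; as a rough sketch of what such a reproduction could look like (matrix sizes and random seeds here are illustrative, not from the original report), one would compare the default matmul against one requested at the highest precision:

```python
import jax
import jax.numpy as jnp

# Illustrative sketch (not the original snippet): compare matmul results
# at two precision settings. On an Ampere GPU where the `precision`
# argument is respected, DEFAULT (tensorfloat32) and HIGHEST (float32)
# should differ noticeably; if the bug is present, both paths agree.
a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

default = jnp.matmul(a, b)  # backend default precision
highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

max_diff = float(jnp.max(jnp.abs(default - highest)))
print("max |default - highest| =", max_diff)
```

On a CPU backend both products are computed in float32, so the difference is zero; the discrepancy only shows up where the backend substitutes tensorfloat32 for the default path.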
What jax/jaxlib version are you using?
jax 0.4.1, jaxlib 0.4.1+cuda11.cudnn86
Which accelerator(s) are you using?
GPU, RTX 3090
Additional system info
Ubuntu 20
NVIDIA GPU info