
Automatically doing TensorFloat32 Math on Ampere GPUs #4873

Closed
n2cholas opened this issue Nov 11, 2020 · 11 comments
Labels
enhancement (New feature or request), P2 (eventual) (This ought to be addressed, but has no schedule at the moment. Assignee optional)

Comments

@n2cholas
Contributor

NVIDIA's new Ampere cards support much faster fp32 computation by doing certain operations on TensorCores at a lower precision and accumulating them in fp32. This lower precision format for math is called TensorFloat32 (TF32). TensorFlow 2.4 and PyTorch 1.7 both enable TF32 math by default, and provide flags to disable it.

Can this be supported in JAX as well? Thanks!
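
For reference, these are the flags referred to above in PyTorch and TensorFlow (not JAX APIs; shown only as a sketch of what the other frameworks expose, worth checking against current docs):

import torch
import tensorflow as tf

# PyTorch 1.7: TF32 is enabled by default for matmuls/convolutions on Ampere;
# these flags turn it off.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# TensorFlow 2.4: TF32 execution can be toggled globally.
tf.config.experimental.enable_tensor_float_32_execution(False)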

@froystig added the "enhancement" (New feature or request) label on Nov 12, 2020
@jakevdp
Collaborator

jakevdp commented Nov 12, 2020

Thanks for the question! So it sounds like this is a specialized dtype somewhat similar to bfloat16, which JAX does support, so there is definitely precedent for this sort of thing. That said, for JAX to support TF32, I think XLA would have to support it first, and I'm not certain whether it does. Perhaps @jekbradbury knows?
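
(A tiny illustration of the existing bfloat16 support mentioned above, showing that it is an actual storage dtype in JAX:)

import jax.numpy as jnp

x = jnp.ones((4, 4), dtype=jnp.bfloat16)  # bfloat16 is a first-class array dtype in JAX
print(x.dtype)  # bfloat16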

@jekbradbury
Contributor

So the TF32 case is simpler than bfloat16 because it's not actually a storage format (so, for instance, we would not have to include TF32 in the JAX type system or make it possible to create a TF32-typed array). Instead, it's a configuration option for matrix multiplies and convolutions with float32 inputs running on an NVIDIA A100 chip. We have a precedent for this: matmuls and convolutions with float32 inputs running on TPU chips can use lax.Precision.DEFAULT (rounding to bfloat16), lax.Precision.HIGH (essentially the same as TF32), and lax.Precision.HIGHEST (full float32 precision). We might want to change the names or otherwise modify the precision API, but that's where TF32 support would go.
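
For concreteness, the existing precision knob described above looks roughly like this; the precision argument and the lax.Precision levels are real JAX API, and what each level maps to is backend-dependent (the comments below follow the explanation above):

import numpy as np
import jax.numpy as jnp
from jax import lax

a = jnp.asarray(np.random.normal(size=(1024, 1024)), dtype=jnp.float32)
b = jnp.asarray(np.random.normal(size=(1024, 1024)), dtype=jnp.float32)

# lax.Precision.DEFAULT: fastest path (bfloat16 rounding on TPU)
# lax.Precision.HIGH:    the level TF32 would correspond to, per the comment above
# lax.Precision.HIGHEST: full float32 precision
c_fast  = jnp.dot(a, b, precision=lax.Precision.DEFAULT)
c_tf32  = jnp.dot(a, b, precision=lax.Precision.HIGH)
c_exact = jnp.dot(a, b, precision=lax.Precision.HIGHEST)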

As a prerequisite, XLA:GPU needs to support setting this configuration option on cuBLAS and cuDNN calls. I think that's WIP, and we'll update when it's ready.

@hawkinsp
Collaborator

I'm actually not 100% sure that this isn't already enabled in XLA: the relevant changes have been merged into the TF tree and I believe shared with XLA. But I don't have any Ampere GPUs to test with. Can someone with such a GPU try it out?

@n2cholas
Contributor Author

@hawkinsp I tried on an A100 and it looks like it's not yet enabled in TF-XLA. I posted an issue in TensorFlow with the details of how I tested this: tensorflow/tensorflow#44887.

@mengdong

mengdong commented Apr 17, 2021

@n2cholas @hawkinsp I got JAX to work with TF32, and Flax implements mixed-precision training as well, although JAX doesn't explicitly implement TF32 knobs such as calling cublasSetMathMode when creating cuBLAS handles.

NVIDIA libraries provide an environment variable, NVIDIA_TF32_OVERRIDE=1, that enables TF32 regardless of how the framework calls the CUDA libraries (see the NVIDIA documentation for details).

Simply install JAX into https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow and you get TF32 enabled. I don't think the NGC container is a must, but that is what I tried with.

@n-gao

n-gao commented Aug 18, 2021

Setting NVIDIA_TF32_OVERRIDE=1 does not fix the issue for me. The performance is identical with or without it. Is there any update on this issue?

@n-gao

n-gao commented Aug 20, 2021

According to the official documentation, NVIDIA_TF32_OVERRIDE=0 can only disable Tensor Cores; it cannot enable them.

@skye added the "P2 (eventual)" label on Aug 23, 2021
@mengdong

@n-gao did you use NVIDIA NGC containers? Please see https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnMathType_t regarding why NVIDIA_TF32_OVERRIDE=1 should work.

You should use DCGM to identify Tensor Core usage (TF32); try following metric 1004, for example: https://docs.nvidia.com/datacenter/dcgm/latest/dcgm-user-guide/feature-overview.html#profiling-dcgmproftester

That being said, Tensor Core usage depends on the model and its dimensions: the CUDA compiler chooses the fastest kernel, not the newest, so it is quite common for Tensor Cores to be enabled while your model has no fast Tensor Core kernel available.

@n-gao

n-gao commented Aug 26, 2021

I do not use the NGC container, but should that change anything?
The documentation only mentions:

The TF32 behavior for CUDNN_DEFAULT_MATH can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0.

This doesn't imply that it forces TF32 if it's set to 1.

I will try DCGM soon.

@n-gao

n-gao commented Feb 25, 2022

Update: with the newest version, TF32 is enabled by default:

(pn) gaoni@mdsi-gpu02:~$ ipython
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.0.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp

In [2]: import numpy as np

In [3]: A = jnp.array(np.random.normal(size=(4096, 4096)))

In [4]: %timeit (A@A).block_until_ready()
1.25 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

(pn) gaoni@mdsi-gpu02:~$ NVIDIA_TF32_OVERRIDE=0 ipython
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.0.1 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import numpy as np

In [2]: import jax.numpy as jnp

In [3]: A = jnp.array(np.random.normal(size=(4096, 4096)))

In [4]: %timeit (A@A).block_until_ready()
7.56 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

@n2cholas
Contributor Author

Thanks @n-gao! I get similar results with an RTX 3090, CUDA 11.6, cuDNN 8.2:

(test-env) nicv@nicv-desktop ~> ipython
Python 3.10.2 (main, Jan 15 2022, 18:02:07) [GCC 9.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.1.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp; import numpy as np

In [2]: A = jnp.array(np.random.normal(size=(4096, 4096)))

In [3]: %timeit (A@A).block_until_ready()
3.98 ms ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [4]: %timeit (A@A).block_until_ready()
4 ms ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [5]:                                                                         
Do you really want to exit ([y]/n)? y
(test-env) nicv@nicv-desktop ~> NVIDIA_TF32_OVERRIDE=0 ipython
Python 3.10.2 (main, Jan 15 2022, 18:02:07) [GCC 9.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.1.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import jax.numpy as jnp; import numpy as np

In [2]: A = jnp.array(np.random.normal(size=(4096, 4096)))

In [3]: %timeit (A@A).block_until_ready()
5.64 ms ± 123 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [4]: %timeit (A@A).block_until_ready()
5.94 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
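
For completeness, TF32 can also be controlled from within JAX rather than via the environment variable, assuming a recent enough JAX that exposes the jax_default_matmul_precision config and the jax.default_matmul_precision context manager (a minimal sketch; check your version):

import numpy as np
import jax
import jax.numpy as jnp

# Request full float32 matmul precision globally (opts out of TF32/bfloat16 fast paths).
jax.config.update("jax_default_matmul_precision", "float32")

A = jnp.asarray(np.random.normal(size=(4096, 4096)), dtype=jnp.float32)
C = (A @ A).block_until_ready()

# Or scope the setting, e.g. to explicitly allow TF32 within a block:
with jax.default_matmul_precision("tensorfloat32"):
    C_tf32 = (A @ A).block_until_ready()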
