Automatically doing TensorFloat32 Math on Ampere GPUs #4873
Thanks for the question! So it sounds like this is a specialized dtype somewhat similar to bfloat16.
So the TF32 case is simpler than bfloat16 because it's not actually a storage format (so, for instance, we would not have to include TF32 in the JAX type system or make it possible to create a TF32-typed array). Instead, it's a configuration option for matrix multiplies and convolutions with float32 inputs running on an NVIDIA A100 chip. We have a precedent for this: matmuls and convolutions with float32 inputs running on TPU chips can use lower-precision bfloat16 passes, controlled by a precision option on the operation.
As a prerequisite, XLA:GPU needs to support setting this configuration option on cuBLAS and cuDNN calls. I think that's WIP, and we'll update when it's ready.
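For reference, that TPU precedent is already visible in the public API: matmul-like ops take a `precision` argument. A minimal sketch (the reduced-precision behavior only kicks in on TPU, but the argument is accepted on every backend):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.float32)

# On TPU, the default precision lets float32 matmuls run as faster
# lower-precision passes; HIGHEST requests full float32 accuracy.
fast = jnp.dot(x, x)  # precision=jax.lax.Precision.DEFAULT
accurate = jnp.dot(x, x, precision=jax.lax.Precision.HIGHEST)
```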
I'm actually not 100% sure that this isn't already enabled in XLA: the relevant changes have been merged into the TF tree and I believe shared with XLA. But I don't have any Ampere GPUs to test with. Can someone with such a GPU try it out?
@hawkinsp I tried on an A100 and it looks like it's not yet enabled in TF-XLA. I posted an issue in TensorFlow with the details of how I tested this: tensorflow/tensorflow#44887.
@n2cholas @hawkinsp I got JAX to work with TF32, and Flax implements mixed-precision training as well. Although JAX doesn't explicitly implement TF32 knobs such as calling cublasSetMathMode when creating cuBLAS handles, the NVIDIA libraries provide an environment variable, NVIDIA_TF32_OVERRIDE. Simply install JAX into https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow and you get TF32 enabled. I think the NGC container is not a must, but that is what I tried with.
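A minimal sketch of that experiment, assuming the installed CUDA libraries honor NVIDIA_TF32_OVERRIDE (it has to be set before the CUDA libraries are initialized, hence before any GPU work; the matrix size is just an example):

```python
import os

# "0" disables TF32; leave the variable unset to keep the library defaults.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

import numpy as np
import jax.numpy as jnp

a = jnp.array(np.random.normal(size=(4096, 4096)).astype(np.float32))
b = (a @ a).block_until_ready()  # time this with and without the override
```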
Setting NVIDIA_TF32_OVERRIDE=1 did not change anything for me.
According to the official documentation, NVIDIA_TF32_OVERRIDE=0 disables TF32.
@n-gao did you use NVIDIA NGC containers? Please see https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnMathType_t regarding why. You should use DCGM to identify Tensor Core (TF32) usage; try following the 1004 metric, for example: https://docs.nvidia.com/datacenter/dcgm/latest/dcgm-user-guide/feature-overview.html#profiling-dcgmproftester. That being said, Tensor Core usage depends on the model and its dimensions: the CUDA compiler chooses the fastest kernel, not the latest, so it is very common that Tensor Cores are enabled but your model doesn't have a fast Tensor Core kernel for its shapes.
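A rough sketch of watching that counter while JAX keeps the GPU busy, assuming dcgmi is installed and that its dmon subcommand accepts these flags on your DCGM version (field 1004 is the tensor-core pipe activity metric from the linked docs; the shapes are arbitrary):

```python
import subprocess

import numpy as np
import jax.numpy as jnp

# Sample DCGM field 1004 in the background: 20 samples at the default interval.
monitor = subprocess.Popen(["dcgmi", "dmon", "-e", "1004", "-c", "20"])

a = jnp.array(np.random.normal(size=(8192, 8192)).astype(np.float32))
for _ in range(200):
    b = a @ a  # float32 matmuls that may or may not hit Tensor Cores
b.block_until_ready()

monitor.wait()
```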
I do not use the NGC container, but should that change anything?
This doesn't imply that it forces TF32 if it's set to 1. I will try DCGM soon.
Update: with the newest version, TF32 is enabled by default:
(pn) gaoni@mdsi-gpu02:~$ ipython
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.0.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax.numpy as jnp
In [2]: import numpy as np
In [3]: A = jnp.array(np.random.normal(size=(4096, 4096)))
In [4]: %timeit (A@A).block_until_ready()
1.25 ms ± 149 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
(pn) gaoni@mdsi-gpu02:~$ NVIDIA_TF32_OVERRIDE=0 ipython
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
Type 'copyright', 'credits' or 'license' for more information
IPython 8.0.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import numpy as np
In [2]: import jax.numpy as jnp
In [3]: A = jnp.array(np.random.normal(size=(4096, 4096)))
In [4]: %timeit (A@A).block_until_ready()
7.56 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Thanks @n-gao! I get similar results as well with an RTX 3090, CUDA 11.6, cuDNN 8.2:
(test-env) nicv@nicv-desktop ~> ipython
Python 3.10.2 (main, Jan 15 2022, 18:02:07) [GCC 9.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.1.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax.numpy as jnp; import numpy as np
In [2]: A = jnp.array(np.random.normal(size=(4096, 4096)))
In [3]: %timeit (A@A).block_until_ready()
3.98 ms ± 21.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [4]: %timeit (A@A).block_until_ready()
4 ms ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [5]:
Do you really want to exit ([y]/n)? y
(test-env) nicv@nicv-desktop ~> NVIDIA_TF32_OVERRIDE=0 ipython
Python 3.10.2 (main, Jan 15 2022, 18:02:07) [GCC 9.3.0]
Type 'copyright', 'credits' or 'license' for more information
IPython 8.1.0 -- An enhanced Interactive Python. Type '?' for help.
In [1]: import jax.numpy as jnp; import numpy as np
In [2]: A = jnp.array(np.random.normal(size=(4096, 4096)))
In [3]: %timeit (A@A).block_until_ready()
5.64 ms ± 123 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [4]: %timeit (A@A).block_until_ready()
5.94 ms ± 16.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
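Since TF32 is now the default, here is a sketch of how to opt back into full float32 matmuls from inside JAX rather than through the environment variable, assuming a JAX version that provides jax.default_matmul_precision:

```python
import jax
import jax.numpy as jnp
import numpy as np

a = jnp.array(np.random.normal(size=(4096, 4096)).astype(np.float32))

# Per-call: request full float32 accuracy for this matmul only.
exact = jnp.dot(a, a, precision=jax.lax.Precision.HIGHEST)

# Scoped: every matmul under this context manager uses full float32.
with jax.default_matmul_precision("float32"):
    also_exact = a @ a
```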
NVIDIA's new Ampere cards support much faster fp32 computation by doing certain operations on Tensor Cores at a lower precision and accumulating the results in fp32. This lower-precision math format is called TensorFloat-32 (TF32). TensorFlow 2.4 and PyTorch 1.7 both enable TF32 math by default and provide flags to disable it.
Can this be supported in JAX as well? Thanks!
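To make the lower precision concrete: TF32 keeps float32's 8-bit exponent but only 10 explicit mantissa bits (float32 has 23). A small numpy sketch of the effect, using plain truncation of the low mantissa bits rather than the hardware's actual rounding:

```python
import numpy as np

def truncate_to_tf32(x):
    """Zero the low 13 mantissa bits of a float32, leaving TF32's 10 bits."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

x = np.float32(1 + 2**-12)       # exactly representable in float32
print(x, truncate_to_tf32(x))    # the 2**-12 term is lost at TF32 precision
```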