
Precision argument for matmul is not respected on Ampere GPUs like the RTX 3090 #14022

Closed
jaschau opened this issue Jan 15, 2023 · 8 comments
Assignees
Labels
bug (Something isn't working) · NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments


jaschau commented Jan 15, 2023

Description

As mentioned in issue #4873, matrix multiplication defaults to TensorFloat-32 (TF32) on Ampere GPUs like the RTX 3090.
However, it would be nice to be able to choose between TF32 and float32 via the precision argument of JAX NumPy's matrix operations. It seems this is currently not possible: TF32 is used irrespective of the value of the precision argument. So at the moment, the only way to disable TF32 appears to be globally, by setting the NVIDIA_TF32_OVERRIDE=0 environment variable.

See below for a minimal reproduction:

import jax
import jax.random as jrandom
import jax.numpy as jnp
import numpy as np

# compute matrix matrix product between 2 matrices of size (2, n) and (n, 2)
def compute_matrix_product(n, precision):
    key = jrandom.PRNGKey(seed=42)
    key_A, key_B = jrandom.split(key, 2)
    A = jrandom.uniform(key_A, shape=(2, n))
    B = jrandom.uniform(key_B, shape=(n, 2))
    return jnp.matmul(A, B, precision=precision)

cpu_device = jax.devices("cpu")[0]
gpu_device = jax.devices("gpu")[0]

n = 10
with jax.default_device(cpu_device):
    res_cpu = compute_matrix_product(n, precision="float32")
with jax.default_device(gpu_device):
    res_gpu_float32_precision = compute_matrix_product(n, precision="float32")
    res_gpu_tensorfloat32_precision = compute_matrix_product(n, precision="tensorfloat32")

# output difference between GPU solutions computed with different precision arguments
# expected output on Ampere GPU like the RTX 3090: finite number due to different precision between tensorfloat32 and float32
# actual output: difference on GPU with float32 and tensorfloat32 precision:  0.0
# so it looks like the same floating point format is used irrespective of the precision argument
print("difference on GPU with float32 and tensorfloat32 precision: ", np.mean(res_gpu_float32_precision - res_gpu_tensorfloat32_precision))

# check difference to CPU solution
# expected output:  small number on the level of float32 precision
# when run with NVIDIA_TF32_OVERRIDE=0, output is -5.9604645e-08
# actual output: difference to CPU solution:  0.00027871132
print("difference to CPU solution: ", np.mean(res_cpu - res_gpu_float32_precision))
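Besides the NVIDIA_TF32_OVERRIDE environment variable, there is also a process-wide default in JAX itself; a minimal sketch, assuming a jax version where the `jax.default_matmul_precision` context manager is available:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 4), dtype=jnp.float32)

# Request full-float32 matmuls for everything in this scope,
# regardless of any per-call precision= argument.
with jax.default_matmul_precision("float32"):
    y = jnp.matmul(x, x)

print(float(y.sum()))  # 64.0: each entry of the 4x4 result is 4.0
```

The same default can also be set globally via `jax.config.update("jax_default_matmul_precision", "float32")`. Of course, if the per-call precision argument is ignored on this path, this global setting would need the same verification as above.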

What jax/jaxlib version are you using?

jax 0.4.1, jaxlib 0.4.1+cuda11.cudnn86

Which accelerator(s) are you using?

GPU, RTX 3090

Additional system info

Ubuntu 20

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 455.32.00    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro P620         On   | 00000000:04:00.0 Off |                  N/A |
| 34%   31C    P8    N/A /  N/A |    219MiB /  1991MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 3090    On   | 00000000:09:00.0 Off |                  N/A |
|  0%   38C    P8    26W / 350W |   2129MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
@jaschau jaschau added the bug Something isn't working label Jan 15, 2023
@jaschau jaschau changed the title Precision argument is not respected on Ampere GPUs like the RTX 3090 Precision argument for matmul is not respected on Ampere GPUs like the RTX 3090 Jan 15, 2023
@hawkinsp hawkinsp added the NVIDIA GPU Issues specific to NVIDIA GPUs label Jan 18, 2023
@jprabhas jprabhas self-assigned this Jan 18, 2023

yhtang commented Feb 3, 2023

When I tried to run the example, I could not reproduce the issue on the 3090, but I could reproduce it on the A100:

root@474855ac51bf:/workspace# nvidia-smi -L
GPU 0: NVIDIA GeForce RTX 3090 (UUID: GPU-e95e1f9d-d019-4edc-464e-d24c104d47bd)
root@474855ac51bf:/workspace# python test.py 
difference on GPU with float32 and tensorfloat32 precision:  0.0
difference to CPU solution:  -5.9604645e-08
root@12ebabe78689:/workspace# nvidia-smi -L
GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-21599b3e-aa53-2e4a-4d62-e5a6d9960330)
GPU 1: NVIDIA A100-SXM4-80GB (UUID: GPU-11812724-4b49-fca2-4d6a-c03ffe9cb5e7)
GPU 2: NVIDIA A100-SXM4-80GB (UUID: GPU-5cba534b-3791-d63e-409c-b2af5fa8e846)
GPU 3: NVIDIA A100-SXM4-80GB (UUID: GPU-043bd904-ac80-f9e8-7427-59f33bd503c0)
GPU 4: NVIDIA A100-SXM4-80GB (UUID: GPU-e9d548fd-a5de-a032-6092-5428bbe6fb99)
GPU 5: NVIDIA A100-SXM4-80GB (UUID: GPU-da6e8085-b9b1-f1ee-2404-ea2b28da7341)
GPU 6: NVIDIA A100-SXM4-80GB (UUID: GPU-de97d9d0-2228-8af5-3229-4c6dad92a540)
GPU 7: NVIDIA A100-SXM4-80GB (UUID: GPU-e91fdbef-eb74-e537-fe4f-923f878593f5)
root@12ebabe78689:/workspace# python test.py 
difference on GPU with float32 and tensorfloat32 precision:  0.0
difference to CPU solution:  0.00027871132

This is using the latest development branch.
JAX:

commit b8d6efe22fccfda9f3b03ab2e53e13814a4bf6a8 (grafted, HEAD)
Author: jax authors <[email protected]>
Date:   Thu Feb 2 23:25:39 2023 -0800


jaschau commented Feb 4, 2023

Thanks for giving this a try! Interestingly, the RTX 3090 reports zero difference between the float32 and tensorfloat32 calculations on the GPU, while the difference to the CPU solution is at the level of float32 machine precision. So it seems the RTX 3090 in your case uses float32 no matter what, while the A100 uses tensorfloat32 irrespective of the precision argument.
Could you check with jax 0.4.1 to see if there's a difference between the released and the dev version?
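The magnitudes seen above are consistent with the two formats. As a back-of-the-envelope check (a NumPy sketch, not JAX's actual TF32 code path), one can truncate float32 inputs to TF32's 10 explicit mantissa bits and compare both products against a float64 reference:

```python
import numpy as np

def truncate_to_tf32(x):
    # TF32 keeps 10 explicit mantissa bits (float32 has 23), so zero
    # out the low 13 bits of the float32 representation. Hardware uses
    # round-to-nearest rather than truncation; this is only meant to
    # estimate the order of magnitude of the error.
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
n = 10
A = rng.random((2, n), dtype=np.float32)
B = rng.random((n, 2), dtype=np.float32)

exact = A.astype(np.float64) @ B.astype(np.float64)
fp32 = A @ B                                      # plain float32 product
tf32 = truncate_to_tf32(A) @ truncate_to_tf32(B)  # TF32-rounded inputs

err_fp32 = np.max(np.abs(fp32 - exact))  # ~1e-7, float32-level error
err_tf32 = np.max(np.abs(tf32 - exact))  # orders of magnitude larger
```

With errors around 1e-7 for float32 and around 1e-4 to 1e-3 for simulated TF32, the measured deviations (-5.9604645e-08 vs. 0.00027871132) line up with a float32 and a TF32 matmul respectively.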


pschuh commented Feb 6, 2023

This should be fixed in the upcoming jaxlib release.


pschuh commented Feb 6, 2023

You might get somewhere with setting: XLA_FLAGS=--xla_gpu_enable_xla_runtime_executable=false until then?
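One detail worth noting when trying this workaround: XLA reads XLA_FLAGS when jax initializes its backends, so the variable must be set before the first jax import in the process (or in the shell before launching Python). A sketch, assuming the jaxlib version discussed here still recognizes this flag (it may be removed in later releases):

```python
import os

# Must happen before the first `import jax`, since the flag is read
# when the XLA backend is initialized.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_xla_runtime_executable=false"

import jax  # noqa: E402  (import after the env setup is intentional)

print(jax.__version__)
```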


jprabhas commented Feb 7, 2023

Thanks! It seems that XLA_FLAGS=--xla_gpu_enable_xla_runtime_executable=false does not resolve the problem on A100 GPUs with JAX 0.4.2. Was the fix done in the new runtime code?


yhtang commented Feb 7, 2023

I can confirm this is fixed in our latest internal build:

root@80a23672e392:/opt# python --version
Python 3.11.1
root@80a23672e392:/opt# python -c 'import jax; print(jax.__version__)'
0.4.3
root@80a23672e392:/opt# (cd jax-source/ && git show --summary)
commit 219723c73817ffb884f4b87d7ed81be5a4354254 (HEAD -> main, origin/main, origin/HEAD)
Author: Roy Frostig <[email protected]>
Date:   Mon Feb 6 22:51:50 2023 -0800

    migrate internal dependencies from `jax.interpreters.ad` to `jax._src.interpreters.ad`
    
    ... in preparation for paring down `jax.interpreters.ad`'s exported symbols.
    
    Includes some import fixups along the way.
    
    PiperOrigin-RevId: 507684262

root@80a23672e392:/opt# (cd xla-source/ && git show --summary)
commit 19f20fd85f0a427c804ac8afda7110de3b416201 (HEAD -> master, origin/master, origin/HEAD)
Author: Justin Lebar <[email protected]>
Date:   Mon Feb 6 23:14:57 2023 -0800

    Add experimental `friendly_name` config option.
    
    Today when you use XLA autoclustering, the XLA module names have the form e.g.
    
      cluster_15111732669523428041_0__XlaCompiledKernel_true__XlaHasReferenceVars_false__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.2393
    
    If you have a program with many TensorFlow nets, it is hard to tell which
    cluster_1234 belongs to which net.
    
    With this change, you can set a name in the TF session config and the name gets
    propagated down to the XLA cluster name:
    
      my_friendly_name_15111732669523428041_0__<snip>
    
    For now this only works for nets that enable XLA via autoclustering.  I'd like
    to do something similar for nets that use tf.function and inference-converter.
    
    PiperOrigin-RevId: 507687578

root@80a23672e392:/opt# python /root/test.py 
difference on GPU with float32 and tensorfloat32 precision:  0.00027871132
difference to CPU solution:  0.0
root@80a23672e392:/opt# nvidia-smi -L
GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-d300aaf9-ef68-bc89-b4be-b3498b1e325a)
GPU 1: NVIDIA A100-SXM4-80GB (UUID: GPU-ce7de1f8-392a-2fff-a3be-375075aa4ed3)
GPU 2: NVIDIA A100-SXM4-80GB (UUID: GPU-59ac08d9-b32c-3898-263c-37d4f0949314)
GPU 3: NVIDIA A100-SXM4-80GB (UUID: GPU-39c717f4-d8cb-6c08-5e50-072e38abeb30)
GPU 4: NVIDIA A100-SXM4-80GB (UUID: GPU-c78dff09-1fc3-8575-0774-78ec78509097)
GPU 5: NVIDIA A100-SXM4-80GB (UUID: GPU-c41d5231-68d5-1d8a-11a8-bb9edf73d11c)
GPU 6: NVIDIA A100-SXM4-80GB (UUID: GPU-1f276a9c-8b31-f115-7b25-649f88b8dbda)
GPU 7: NVIDIA A100-SXM4-80GB (UUID: GPU-0c7b0987-3081-3387-cec8-76c702847270)


yhtang commented Feb 7, 2023

> Thanks! It seems that XLA_FLAGS=--xla_gpu_enable_xla_runtime_executable=false does not resolve the problem on A100 GPUs with JAX 0.4.2. Was the fix done in the new runtime code?

It's in the latest development branch, so a nightly build or the next release (0.4.3) should have it.


jprabhas commented Feb 8, 2023

I confirmed that this issue is fixed now. We can close this bug.
