
Precision control has no effect #14157

Closed
yhtang opened this issue Jan 25, 2023 · 5 comments
Labels: bug, NVIDIA GPU, XLA

Comments

yhtang (Collaborator) commented Jan 25, 2023

Possibly related to #13038.
Possible duplicate of #14022.

Description

Neither the jax.default_matmul_precision context manager nor the precision= argument to jax.numpy.matmul changes the precision of matrix multiplication on an A100, and the result does not reach a precision comparable to NumPy's float32 matmul.

import numpy as np
import jax
import jax.numpy as jnp
from jax.lax import Precision

A = np.array([[1.1111111, 2.2222222], [3.3333333, 4.4444444]], dtype=np.float32)
Aj = jnp.array(A)

# float32 reference (NumPy)
A @ A.T
> array([[ 6.1728387, 13.580245 ],
>        [13.580245 , 30.864195 ]], dtype=float32)

Aj @ Aj.T
> Array([[ 6.175251, 13.585552],
>        [13.585552, 30.876255]], dtype=float32)

with jax.default_matmul_precision("float32"):
    print(Aj @ Aj.T)
> [[ 6.175251 13.585552]
>  [13.585552 30.876255]]

jnp.matmul(Aj, Aj.T, precision=Precision.DEFAULT)
> Array([[ 6.175251, 13.585552],
>        [13.585552, 30.876255]], dtype=float32)

jnp.matmul(Aj, Aj.T, precision=Precision.HIGH)
> Array([[ 6.175251, 13.585552],
>        [13.585552, 30.876255]], dtype=float32)

jnp.matmul(Aj, Aj.T, precision=Precision.HIGHEST)
> Array([[ 6.175251, 13.585552],
>        [13.585552, 30.876255]], dtype=float32)
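The values above are consistent with the A100's TF32 tensor-core path, which rounds each float32 input to a 10-bit mantissa before multiplying. A minimal NumPy-only sketch (the `to_tf32` helper and its round-to-nearest bit trick are illustrative assumptions, not the hardware's exact rounding mode) reproduces the reported numbers:

```python
import numpy as np

def to_tf32(x):
    # Illustrative: emulate TF32 by rounding a float32 mantissa (23 bits)
    # down to 10 bits, via round-to-nearest on the raw bit pattern.
    # Real tensor-core rounding may differ in tie-breaking.
    i = np.asarray(x, np.float32).view(np.uint32)
    i = (i + np.uint32(0x1000)) & np.uint32(0xFFFFE000)
    return i.view(np.float32)

A = np.array([[1.1111111, 2.2222222], [3.3333333, 4.4444444]], dtype=np.float32)

ref = A @ A.T                     # full float32 matmul: ~6.1728387 in [0, 0]
emu = to_tf32(A) @ to_tf32(A).T   # TF32-rounded inputs: ~6.175251 in [0, 0]
print(ref)
print(emu)
```

The emulated result matches the GPU output above to the printed digits, which suggests the matmul is being dispatched to TF32 regardless of the requested precision.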

What jax/jaxlib version are you using?

Development branch

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.11.1, Ubuntu 20.04

NVIDIA GPU info

    Wed Jan 25 19:41:52 2023       
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 470.141.10   Driver Version: 470.141.10   CUDA Version: 11.8     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  NVIDIA A100-SXM...  On   | 00000000:07:00.0 Off |                    0 |
    | N/A   31C    P0    84W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   1  NVIDIA A100-SXM...  On   | 00000000:0F:00.0 Off |                    0 |
    | N/A   28C    P0    63W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   2  NVIDIA A100-SXM...  On   | 00000000:47:00.0 Off |                    0 |
    | N/A   27C    P0    60W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   3  NVIDIA A100-SXM...  On   | 00000000:4E:00.0 Off |                    0 |
    | N/A   26C    P0    62W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   4  NVIDIA A100-SXM...  On   | 00000000:87:00.0 Off |                    0 |
    | N/A   35C    P0    63W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   5  NVIDIA A100-SXM...  On   | 00000000:90:00.0 Off |                    0 |
    | N/A   32C    P0    61W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   6  NVIDIA A100-SXM...  On   | 00000000:B7:00.0 Off |                    0 |
    | N/A   32C    P0    61W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
    |   7  NVIDIA A100-SXM...  On   | 00000000:BD:00.0 Off |                    0 |
    | N/A   31C    P0    61W / 400W |      0MiB / 81251MiB |      0%      Default |
    |                               |                      |             Disabled |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+

@yhtang yhtang added bug Something isn't working NVIDIA GPU Issues specific to NVIDIA GPUs labels Jan 25, 2023
yhtang (Collaborator, author) commented Jan 25, 2023

viz @nouiz @mjsML

hawkinsp (Collaborator) commented:

Can you dump the HLO for a high precision matmul? XLA_FLAGS=--xla_dump_to=/tmp/somewhere. I'm wondering if JAX is losing the precision annotation or if XLA isn't acting on it correctly.
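A sketch of how the annotation can be checked without a filesystem dump: `jax.jit(...).lower(...).as_text()` prints the lowered module, including the precision config, as an alternative to setting `XLA_FLAGS=--xla_dump_to=...` (the function `f` below is made up for illustration; the flag route requires setting the environment variable before JAX initializes its backend).

```python
import jax
import jax.numpy as jnp
from jax.lax import Precision

def f(x):
    # A HIGHEST-precision matmul, as in the original report.
    return jnp.matmul(x, x.T, precision=Precision.HIGHEST)

x = jnp.ones((2, 2), jnp.float32)
# Lower without executing and inspect the textual IR.
hlo = jax.jit(f).lower(x).as_text()
print("HIGHEST" in hlo)  # the precision annotation should survive lowering
```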

yhtang (Collaborator, author) commented Jan 25, 2023

> Can you dump the HLO for a high precision matmul? XLA_FLAGS=--xla_dump_to=/tmp/somewhere. I'm wondering if JAX is losing the precision annotation or if XLA isn't acting on it correctly.

So the annotation is apparently preserved:

# precision = DEFAULT
HloModule jit_f, is_scheduled=true, entry_computation_layout={()->f32[2,2]{1,0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY %main.4 () -> f32[2,2] {
  %constant_3 = f32[2,2]{1,0} constant({ { 6.17283869, 13.580246 }, { 13.580246, 30.8641949 } }), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=(<Precision.DEFAULT: 0>, <Precision.DEFAULT: 0>) preferred_element_type=None]" source_file="/tmp/ipykernel_3552/16315587.py" source_line=3}
  ROOT %copy = f32[2,2]{1,0} copy(f32[2,2]{1,0} %constant_3)
}

# precision = HIGH
HloModule jit_f, is_scheduled=true, entry_computation_layout={()->f32[2,2]{1,0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY %main.4 () -> f32[2,2] {
  %constant_3 = f32[2,2]{1,0} constant({ { 6.17283869, 13.580246 }, { 13.580246, 30.8641949 } }), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=(<Precision.HIGH: 1>, <Precision.HIGH: 1>) preferred_element_type=None]" source_file="/tmp/ipykernel_3552/823144715.py" source_line=3}
  ROOT %copy = f32[2,2]{1,0} copy(f32[2,2]{1,0} %constant_3)
}

# precision = HIGHEST
HloModule jit_f, is_scheduled=true, entry_computation_layout={()->f32[2,2]{1,0}}, allow_spmd_sharding_propagation_to_output=true

ENTRY %main.4 () -> f32[2,2] {
  %constant_3 = f32[2,2]{1,0} constant({ { 6.17283869, 13.580246 }, { 13.580246, 30.8641949 } }), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=(<Precision.HIGHEST: 2>, <Precision.HIGHEST: 2>) preferred_element_type=None]" source_file="/tmp/ipykernel_3552/2439358024.py" source_line=3}
  ROOT %copy = f32[2,2]{1,0} copy(f32[2,2]{1,0} %constant_3)
}

hawkinsp (Collaborator) commented:

Then the annotation has been lost somewhere during lowering in XLA and this is an XLA bug.

yhtang (Collaborator, author) commented Feb 7, 2023

This is apparently solved according to this thread. Thanks everyone!

@yhtang yhtang closed this as completed Feb 7, 2023