
[Question] How to use Relax attention op in mlc-llm workflow #690

Closed

ylc2001 opened this issue Aug 8, 2023 · 2 comments
Labels: question (Question about the usage)
ylc2001 commented Aug 8, 2023

❓ General Questions

Flash Attention v2 was recently added to TVM, and I am trying to use it to optimize the GPU performance of mlc-llm.

Currently, in mlc_llm/relax_model/llama.py, the attention computation is spelled out with many basic operations instead of Relax's fused attention op. I tried to switch to the Relax attention op (see this code), but it didn't quite work out: attention performance is very low, and it does not even seem to be running on the GPU:

Time elapsed: encoding 4.145921468734741 seconds, decoding 0.03513669967651367 secs
Profiling...
======================= Encoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
attention                                         126.0235    32      4032.7533         97.21     100.38          0.7779            (1, 1001, 32, 128), (1, 1001, 32, 128), (1, 1001, 32, 128), (1, 1001, 32, 128), (32, 1001, 1001), (32, 1001), (32, 1001), (32, 1001, 128)
fused_fused_decode4_NT_matmul2                    1.3956      32      44.6596           1.08      98.23           68.7348           (22016, 512), (22016, 128), (1, 1001, 4096), (1, 1001, 22016)
fused_fused_decode5_fused_NT_matmul3_add          0.8974      32      28.7167           0.69      60.85           66.2129           (4096, 1376), (4096, 344), (1, 1001, 11008), (1, 1001, 4096), (1, 1001, 4096)
fused_fused_decode2_NT_matmul                     0.8103      32      25.9304           0.63      58.28           70.2376           (12288, 512), (12288, 128), (1, 1001, 4096), (1, 1001, 12288)
fused_fused_decode3_fused_NT_matmul1_add          0.3381      32      10.8188           0.26      32.46           93.7632           (4096, 512), (4096, 128), (1, 1001, 4096), (1, 1001, 4096), (1, 1001, 4096)
fused_split1_silu_multiply                        0.0685      32      2.1920            0.05      63.05           898.8994          (1, 1001, 22016), (1, 1001, 11008)
split                                             0.0622      32      1.9911            0.05      46.92           736.4342          (1, 1001, 12288), (1, 1001, 4096), (1, 1001, 4096), (1, 1001, 4096)
rms_norm                                          0.0203      65      1.3225            0.03      15.65           751.0578          (1, 1001, 4096), (4096,), (1, 1001, 4096)
fused_fused_decode1_fused_NT_matmul4_cast         0.0673      1       0.0673            0.00      70.44           1021.9943         (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
fused_fused_decode1_take                          0.0149      1       0.0149            0.00      78.14           5117.9188         (32000, 512), (32000, 128), (1001,), (1001, 4096)
slice                                             0.0027      1       0.0027            0.00      7.83            2878.7277         (1, 1001, 4096), (1, 1, 4096)
Total time: 4148.4693 ms

======================= Decoding Profiling =======================
Name                                              Time (ms)   Count   Total time (ms)   Pct (%)   Memory (MB)     Bandwidth (GB/s)  Shape
attention1                                        0.8677      32      27.7673           85.70     15.74           17.7153           (1, 1, 32, 128), (1, 1002, 32, 128), (1, 1002, 32, 128), (1, 1, 32, 128), (32, 1, 1002), (32, 1), (32, 1), (32, 1, 128)
fused_fused_decode4_NT_matmul7                    0.0481      32      1.5406            4.76      48.42           982.2368          (22016, 512), (22016, 128), (1, 1, 4096), (1, 1, 22016)
fused_fused_decode5_fused_NT_matmul8_add1         0.0340      32      1.0884            3.36      24.22           695.4920          (4096, 1376), (4096, 344), (1, 1, 11008), (1, 1, 4096), (1, 1, 4096)
fused_fused_decode2_NT_matmul5                    0.0288      32      0.9223            2.85      27.03           915.8815          (12288, 512), (12288, 128), (1, 1, 4096), (1, 1, 12288)
rms_norm1                                         0.0080      65      0.5225            1.61      0.02            2.8474            (1, 1, 4096), (4096,), (1, 1, 4096)
fused_fused_decode3_fused_NT_matmul6_add1         0.0124      32      0.3965            1.22      9.02            711.1918          (4096, 512), (4096, 128), (1, 1, 4096), (1, 1, 4096), (1, 1, 4096)
fused_split2_silu1_multiply1                      0.0029      32      0.0921            0.28      0.06            21.3774           (1, 1, 22016), (1, 1, 11008)
fused_fused_decode1_fused_NT_matmul4_cast         0.0674      1       0.0674            0.21      70.44           1020.3387         (32000, 512), (32000, 128), (1, 1, 4096), (1, 1, 32000)
fused_fused_decode1_take1                         0.0028      1       0.0028            0.01      70.32           24505.3005        (32000, 512), (32000, 128), (1,), (1, 4096)
Total time: 32.4000 ms

There is not much Relax documentation to refer to. Could anyone tell me the proper usage of R.nn.attention, and why mlc-llm does not use it to implement the model?
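
For reference, the kind of usage I have in mind looks roughly like the sketch below (minimal and illustrative, assuming TVM Unity's R.nn.attention with BSNH-layout (batch, seq_len, num_heads, head_dim) inputs; the module and function names are mine, not mlc-llm's):

```python
# Minimal sketch, not the mlc-llm code: express fused attention with the
# Relax attention op in TVMScript. Shapes follow the prefill profile above.
from tvm.script import ir as I, relax as R


@I.ir_module
class AttentionModule:
    @R.function
    def prefill(
        q: R.Tensor((1, 1001, 32, 128), "float16"),
        k: R.Tensor((1, 1001, 32, 128), "float16"),
        v: R.Tensor((1, 1001, 32, 128), "float16"),
    ) -> R.Tensor((1, 1001, 32, 128), "float16"):
        with R.dataflow():
            # One fused scaled-dot-product-attention op. Unless a later pass
            # offloads it to a fused GPU kernel, legalization lowers it to a
            # naive implementation, which would match the slow numbers above.
            out = R.nn.attention(q, k, v)
            R.output(out)
        return out
```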

ylc2001 added the question label Aug 8, 2023
masahi (Contributor) commented Aug 9, 2023

See the rewrite_attention function in #651. Flash v2 is actually not fast for a single-query workload (Dao-AILab/flash-attention#427 (comment)), so for the decoder we use the xFormers kernel.
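
Roughly, the idea is to keep building the model with basic ops and then pattern-match the unfused transpose → matmul → softmax → matmul chain, replacing it with the fused Relax attention op. A simplified, illustrative sketch of such a rewrite (the actual function in #651 also handles casts, score clipping, and the causal mask; this uses TVM's relax.dpl pattern API):

```python
# Simplified sketch of a rewrite_attention-style pass, not the exact code in
# #651: match the unfused attention chain emitted by llama.py and replace it
# with the fused Relax attention op.
from tvm import relax
from tvm.relax.dpl import is_op, wildcard, rewrite_call


def rewrite_attention(f: relax.Function) -> relax.Function:
    Q = wildcard()
    K = wildcard()
    V = wildcard()

    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    Q_BH = is_op("relax.permute_dims")(Q)
    K_BH = is_op("relax.permute_dims")(K)
    V_BH = is_op("relax.permute_dims")(V)

    # scores = (Q_BH @ K_BH^T) / scale, then softmax, then @ V_BH
    matmul1 = is_op("relax.matmul")(Q_BH, is_op("relax.permute_dims")(K_BH))
    scaled = is_op("relax.divide")(matmul1, wildcard())
    softmax = is_op("relax.nn.softmax")(scaled)
    matmul2 = is_op("relax.matmul")(softmax, V_BH)

    # the final transpose back to (batch, seq, heads, head_dim)
    pattern = is_op("relax.permute_dims")(matmul2)

    def callback(_, matchings):
        # Feed the original BSNH tensors straight into the fused op.
        return relax.op.nn.attention(matchings[Q], matchings[K], matchings[V])

    return rewrite_call(pattern, callback, f)
```

After a rewrite like this, the fused attention calls can be offloaded to the external kernels (FlashAttention v2 for the encoder, the xFormers kernel for the decoder) instead of falling back to the naive lowering.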

ylc2001 (Author) commented Aug 9, 2023

Thanks a lot!

ylc2001 closed this as completed Aug 9, 2023