[Hardware][Ascend] MLA for deepseek #88
Conversation
vllm_ascend/attention.py
Outdated
kv_b_proj_weight = self.kv_b_proj.weight.reshape(self.num_heads,
                                                 self.qk_nope_head_dim + self.v_head_dim,
                                                 self.kv_lora_rank)
w_kc = kv_b_proj_weight[:, :self.qk_nope_head_dim, :].contiguous()
Should modify the model loader to enable this at the initialization stage.
We can work around this by setting this weight as an attribute at runtime; that way the slice + contiguous is done only once, compared with every call in this version.
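Something like the following sketch (assumptions: `w_kc` / `w_vc` names mirror the diff, the `w_vc` slice and the `hasattr`-based lazy caching are illustrative, not the PR's actual code):

```python
# Sketch only: lazily cache the sliced projections as attributes so the
# reshape + slice + contiguous cost is paid once, not on every forward.
if not hasattr(self, "w_kc"):
    kv_b_proj_weight = self.kv_b_proj.weight.reshape(
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
        self.kv_lora_rank)
    self.w_kc = kv_b_proj_weight[:, :self.qk_nope_head_dim, :].contiguous()
    self.w_vc = kv_b_proj_weight[:, self.qk_nope_head_dim:, :].contiguous()
w_kc, w_vc = self.w_kc, self.w_vc
```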
vllm_ascend/attention.py
Outdated
compressType=0, calcType=0, scaleType=0, quantType=0,
inputLayout=0, outDataType=-1, attnOut=attn_output)
attn_output_t = torch_npu.npu_transpose(attn_output, (1, 0, 2), require_contiguous=True)
attn_output_t = torch_npu.npu_bmmV2(attn_output_t, w_vc, [])
torch.bmm can probably do the same here.
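For reference, a rough sketch of that idea (assuming `attn_output` is `(num_tokens, num_heads, kv_lora_rank)` and `w_vc` is `(num_heads, kv_lora_rank, v_head_dim)`; whether the native ops match the npu kernels in performance would still need profiling):

```python
# Sketch: native torch equivalent of npu_transpose + npu_bmmV2.
# transpose(0, 1) -> (num_heads, num_tokens, kv_lora_rank), then a batched
# matmul against w_vc of shape (num_heads, kv_lora_rank, v_head_dim).
attn_output_t = attn_output.transpose(0, 1).contiguous()
attn_output_t = torch.bmm(attn_output_t, w_vc)
```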
kv = self.kv_b_proj(kv_c_normed)[0].view(num_tokens, kv_heads_num, -1)
k_nope, value = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_cache = torch.cat([kv_c_normed.view(num_tokens, self.num_kv_heads, -1), k_pe], dim=2)
k_pe = k_pe.repeat(1, self.num_heads, 1)
Can you test torch.expand here? Unlike repeat, it does not touch global memory.
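For reference, a sketch of that suggestion (assuming `k_pe` has a singleton head dimension, e.g. shape `(num_tokens, 1, qk_rope_head_dim)`; expand returns a view without copying, so downstream ops must accept the non-contiguous layout, otherwise a later `.contiguous()` would bring the copy back):

```python
# Sketch: broadcast the singleton head dim as a view instead of copying it
# num_heads times with repeat.
k_pe = k_pe.expand(-1, self.num_heads, -1)
```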
Signed-off-by: YHT <[email protected]>
Co-authored-by: YHT <[email protected]>
Signed-off-by: angazenn <[email protected]>
What this PR does / why we need it?
To adapt the MLA structure of vLLM's DeepSeek support to Ascend hardware, this PR adds the AscendMLAAttentionBackendImpl class.
Does this PR introduce any user-facing change?
Users can set VLLM_MLA_DISABLE to 1 to disable MLA or to 0 to enable it.
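For example (an assumed usage sketch; only the environment variable name and values come from this PR description):

```python
import os

# Set the flag before vLLM is imported/launched so the backend sees it;
# "1" disables MLA, "0" leaves it enabled.
os.environ["VLLM_MLA_DISABLE"] = "1"
```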
How was this patch tested?