feat: Optimize qwen2-vl to reduce cudaMemcpyAsync
cynthieye committed Mar 6, 2025
1 parent 9f1710f commit 3f21ec2
Showing 1 changed file with 34 additions and 9 deletions.
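
What the change does: `(cu_seqlens[1:] - cu_seqlens[:-1]).max().item()` (FLASH_ATTN path) and `(cu_seqlens[1:] - cu_seqlens[:-1]).tolist()` (XFORMERS path) each force a device-to-host copy, which shows up as a cudaMemcpyAsync plus a stream synchronization. Previously these ran inside every vision block's attention forward, i.e. once per layer per forward pass; this commit resolves the attention backend once in `__init__`, computes `max_seqlen`/`seqlens` once per forward in `Qwen2VisionTransformer.forward`, and threads the values down to each block. A minimal sketch of the pattern, with hypothetical names (`VisionBlock` and `run_blocks` are illustrative, not vLLM's API):

```python
# Illustrative sketch of the hoisting pattern this commit applies (not vLLM code).
from typing import List, Optional

import torch


class VisionBlock(torch.nn.Module):
    """Stand-in for a transformer block that needs per-sequence lengths."""

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                max_seqlen: Optional[int] = None,
                seqlens: Optional[List[int]] = None) -> torch.Tensor:
        # Before the change, each block computed this itself, forcing a
        # device-to-host copy (cudaMemcpyAsync + sync) in every layer:
        #   max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        return x  # attention itself omitted


def run_blocks(blocks, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    # After the change: compute the host-side values once, outside the loop,
    # and pass them to every block.
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    max_seqlen = lengths.max().item()  # one DtoH copy per forward pass
    seqlens = lengths.tolist()         # only one of the two is needed per backend
    for blk in blocks:
        x = blk(x, cu_seqlens, max_seqlen=max_seqlen, seqlens=seqlens)
    return x
```
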
43 changes: 34 additions & 9 deletions vllm/model_executor/models/qwen2_vl.py
@@ -25,7 +25,7 @@
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property, partial
-from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
+from typing import (Any, Callable, List, Literal, Optional, Set, Tuple, TypedDict,
                     Union)
 
 import torch
@@ -307,6 +307,8 @@ def forward(
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
+        max_seqlen: int = None,
+        seqlens: List = None,
     ) -> torch.Tensor:
 
         # [s, b, c] --> [s, b, 3 * head * head_dim]
@@ -329,7 +331,6 @@ def forward(
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
             output = flash_attn_varlen_func(q,
                                             k,
                                             v,
@@ -365,7 +366,6 @@ def forward(
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
-            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
             attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                        kv_seqlen=None,
                                                        device=q.device)
@@ -409,11 +409,22 @@ def __init__(
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
 
-    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
-                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x),
-                          cu_seqlens=cu_seqlens,
-                          rotary_pos_emb=rotary_pos_emb)
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor,
+        max_seqlen: int,
+        seqlens: List,
+    ) -> torch.Tensor:
+        x = x + self.attn(
+            self.norm1(x),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+            max_seqlen=max_seqlen,
+            seqlens=seqlens,
+        )
+
         x = x + self.mlp(self.norm2(x))
         return x
 
@@ -570,6 +581,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
+        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
 
     @property
     def dtype(self) -> torch.dtype:
@@ -624,8 +636,21 @@ def forward(
 
         # transformers
         x = x.unsqueeze(1)
+
+        max_seqlen = None
+        seqlens = None
+        if self.attn_backend == _Backend.FLASH_ATTN:
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        elif self.attn_backend == _Backend.XFORMERS:
+            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         for blk in self.blocks:
-            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
+            x = blk(
+                x,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+                max_seqlen=max_seqlen,
+                seqlens=seqlens,
+            )
 
         # adapter
         x = self.merger(x)
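One way to sanity-check the effect (an assumption, not part of this commit) is to count device-to-host memcpy events around the vision forward with torch.profiler: with the old code the count grows with the number of vision blocks, while after this change it should stay roughly constant per forward pass. A small sketch; it requires a CUDA device, and exact event names vary across PyTorch/CUDA versions:

```python
# Hedged verification sketch (not from the commit): count DtoH memcpys
# recorded while running a callable on GPU.
import torch
from torch.profiler import ProfilerActivity, profile


def count_dtoh(fn, *args) -> int:
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn(*args)
        torch.cuda.synchronize()
    # Event names like "Memcpy DtoH (Device -> Pageable)" depend on the
    # PyTorch/CUDA version; adjust the substring if needed.
    return sum(evt.count for evt in prof.key_averages() if "Memcpy DtoH" in evt.key)
```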
