Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Revert uncache_blocks and support recaching full blocks #12415

Merged
merged 2 commits into from
Feb 3, 2025

Conversation

comaniac
Copy link
Collaborator

@comaniac comaniac commented Jan 24, 2025

Revert #12333 as we realized that it's unsafe to uncache a block because it may be used by more than one requests. On the other hand, this PR supports _cache_full_blocks to deal with the already cached blocks, so that speculative decoding can use it in the following way:

# We must have at least these many tokens that have already been computed.
# Note that this doesn't have to be super accurate because this is just an optimization that avoid
# us always checking whether the first N blocks are already cached in every step.
min_num_last_step_computed_tokens = request.num_computed_tokens - k
# So we must have these blocks cached already.
min_num_last_step_computed_full_blocks = num_last_step_computed_tokens // self.block_size
# And we must have these blocks to be cached after appending `num_tokens` (e.g., bonus token, etc).
num_full_blocks_after_append = (request.num_computed_tokens + num_tokens_wo_spec_tokens) // self.block_size

# Now `new_full_blocks` may start with an already cached block,
# but the change in this PR can handle this case.
new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks_after_append]
if new_full_blocks:
    self._cache_full_blocks(
        request=request,
        blk_start_idx=num_computed_full_blocks,
        full_blocks=new_full_blocks,
        prev_block=req_blocks[num_computed_full_blocks - 1]
        if num_computed_full_blocks >= 1 else None,
    )

cc @LiuXiaoxuanPKU

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: Cody Yu <[email protected]>
@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 24, 2025
@comaniac comaniac enabled auto-merge (squash) February 3, 2025 18:32
@mergify mergify bot added the v1 label Feb 3, 2025
Copy link
Collaborator

@LiuXiaoxuanPKU LiuXiaoxuanPKU left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the API! LGTM.

@simon-mo simon-mo disabled auto-merge February 3, 2025 23:04
@simon-mo simon-mo merged commit 5095e96 into vllm-project:main Feb 3, 2025
31 of 34 checks passed
fxmarty-amd pushed a commit to fxmarty-amd/vllm that referenced this pull request Feb 7, 2025
ShangmingCai pushed a commit to ShangmingCai/vllm that referenced this pull request Feb 10, 2025
@comaniac comaniac deleted the re-cache branch February 13, 2025 23:40
panf2333 pushed a commit to yottalabsai/vllm that referenced this pull request Feb 18, 2025
kerthcet pushed a commit to kerthcet/vllm that referenced this pull request Feb 21, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Mar 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants