
[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass #10902

Merged
merged 5 commits into vllm-project:main from the luka/fp8-noncutlass-fix branch on Feb 28, 2025

Conversation

Contributor

@ProExpertProg ProExpertProg commented Dec 4, 2024

This PR fixes the fp8 case when cutlass_mm is not available. It contains the following fixes:

  • Removes the padding for fp8 torch._scaled_mm in the torch.compile case, as branch specialization might not work correctly and the padding makes fusion difficult.
  • Implements elimination of redundant slice and slice_scatter ops; PyTorch performs some of this elimination but does not cover all cases (a standalone sketch of the idea follows after this description). The RedundantReshapesPass is renamed to NoopEliminationPass accordingly.
  • Minor custom pass improvements.

This PR is a prerequisite to #10836, which enables torch.compile on AMD and uses the non-cutlass-fp8 path.
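
For illustration, here is a minimal standalone sketch of the no-op slice elimination idea. This is not the vLLM pass itself: the function name, the sentinel-end check, and the reliance on fake-tensor shapes in node.meta["val"] are assumptions based on the diff discussed further down.

```python
# Hypothetical sketch: drop aten.slice ops that cover an entire dimension.
# Assumes an FX graph whose nodes carry fake-tensor shapes in node.meta["val"].
import torch
from torch.fx import GraphModule


def eliminate_noop_slices(gm: GraphModule) -> int:
    count = 0
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor:
            # aten.slice(self, dim=0, start=None, end=None, step=1); pad missing args.
            args = list(node.args) + [None] * (4 - len(node.args))
            inp, dim, start, end = args[:4]
            dim_size = inp.meta["val"].shape[dim if dim is not None else 0]
            # A slice from 0 to (at least) the dimension size changes nothing.
            covers_dim = end is None or (
                isinstance(end, int) and (end == dim_size or end >= 2**63 - 1)
            )
            if start in (0, None) and covers_dim:
                node.replace_all_uses_with(inp)
                gm.graph.erase_node(node)
                count += 1
    # The same idea extends to slice_scatter: writing src over the full dimension
    # of input just yields src, so that node can be replaced with its src argument.
    gm.graph.lint()
    gm.recompile()
    return count
```

The real pass additionally has to handle symbolic (dynamic) dimension sizes, which is what the dims_equivalent helper in the diff below appears to do.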


github-actions bot commented Dec 4, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch from b8ab496 to e5ded5c on December 4, 2024 20:18

mergify bot commented Feb 15, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @ProExpertProg.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 15, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch from e5ded5c to a3cb530 on February 25, 2025 20:52
@mergify mergify bot removed the needs-rebase label Feb 25, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch 2 times, most recently from 65afeae to c22186b on February 26, 2025 19:08
@ProExpertProg ProExpertProg changed the title from "Fix for the padding in the non-cutlass-fp8 case" to "[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case" on Feb 26, 2025
Comment on lines 74 to 84
elif is_func(node, torch.ops.aten.slice.Tensor):
    input, dim_index, start, end = node.args[:4]
    input_shape = input.meta["val"].shape
    i_dim = input_shape[dim_index]

    if start == 0 and self.dims_equivalent(end, i_dim):
        node.replace_all_uses_with(input)
        graph.erase_node(node)
        count += 1

elif is_func(node, torch.ops.aten.slice_scatter.default):
Contributor

Are these always the right ops to use? e.g. is there a torch.ops.aten.slice.default or a torch.ops.aten.slice_scatter.Tensor?

Contributor Author

I haven't seen them, so I'm not sure; I just went off what I saw. The other overloads could be added easily if we ever see them in the graph.
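
For reference, the registered overloads can be listed directly; a quick check along these lines (the exact output depends on the PyTorch version):

```python
import torch

# OpOverloadPacket.overloads() returns the registered overload names.
print(torch.ops.aten.slice.overloads())          # expected to include 'Tensor'
print(torch.ops.aten.slice_scatter.overloads())  # expected to include 'default'
```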

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 26, 2025
@ProExpertProg ProExpertProg force-pushed the luka/fp8-noncutlass-fix branch from 12e173e to 427bb9d on February 27, 2025 22:38
@@ -161,10 +162,14 @@ def apply_fp8_linear(
    # Note: we pad the input because torch._scaled_mm is more performant
    # for matrices with batch dimension > 16.
    # This could change in the future.
    # We also don't pad when using torch.compile,
    # as it breaks with dynamic shapes.
    config = get_current_vllm_config().compilation_config
Member

Is this cached? It could be expensive on each forward call.

Contributor Author

Yes, in eager mode this will get called on every forward pass, but it only happens once when compiled. In eager mode there isn't really a better way that's still correct; the only option is to check the config context. I don't think this getter is significant, but I haven't measured it.
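
A tiny hypothetical example of why the lookup only costs something in eager mode: under torch.compile the Python-level condition is evaluated during tracing and only the taken path ends up in the compiled graph (the module-level flag below just stands in for the config lookup; it is not vLLM code).

```python
import torch

PAD_INPUT = False  # stand-in for get_current_vllm_config().compilation_config


def linear_like(x: torch.Tensor) -> torch.Tensor:
    if PAD_INPUT:  # evaluated at trace time, not on every compiled call
        x = torch.nn.functional.pad(x, (0, 0, 0, 16 - x.shape[0]))
    return x * 2.0


compiled = torch.compile(linear_like)
print(compiled(torch.randn(4, 8)).shape)  # padding branch never appears in the graph
```

torch.compile does guard on the flag, so changing it would trigger a recompile; that matches the point that only eager mode pays for the getter on every forward pass.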

Collaborator

@tlrmchlsmth tlrmchlsmth Feb 28, 2025

We could pass in an allow_input_padding flag? I do think this is annoying though. I think it's worth doing a quick check for performance regressions on a small-model eager-mode benchmark with cutlass_scaled_mm disabled.

Contributor Author

I think we'd have to pass that flag through the whole call stack, though, so I don't think it's worth it. I'll run a small model.

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Looks good overall but I had a few minor comments

- rename cutlass_fp8 test flag
- rename noop pass
- improve some comments

Signed-off-by: luka <[email protected]>
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Thanks for the great work! LGTM assuming we don't see any performance regression

@ProExpertProg
Contributor Author

Yep will post perf numbers once I have them, thanks!

@ProExpertProg ProExpertProg changed the title from "[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case" to "[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass" on Feb 28, 2025
@mgoin mgoin merged commit bd56c98 into vllm-project:main Feb 28, 2025
40 checks passed
johnnynunez pushed a commit to johnnynunez/vllm that referenced this pull request Mar 3, 2025
…e, rename RedundantReshapesPass to NoopEliminationPass (vllm-project#10902)

Signed-off-by: luka <[email protected]>
Signed-off-by: Johnny <[email protected]>
Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
…e, rename RedundantReshapesPass to NoopEliminationPass (vllm-project#10902)

Signed-off-by: luka <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Mar 5, 2025
…e, rename RedundantReshapesPass to NoopEliminationPass (vllm-project#10902)

Signed-off-by: luka <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
tlrmchlsmth added a commit that referenced this pull request Mar 5, 2025
…-fp8 case, rename RedundantReshapesPass to NoopEliminationPass (#10902)"

This reverts commit bd56c98.