[torch.compile] Fix RMSNorm + quant fusion in the non-cutlass-fp8 case, rename RedundantReshapesPass to NoopEliminationPass #10902
Conversation
vllm/compilation/reshapes.py
```python
elif is_func(node, torch.ops.aten.slice.Tensor):
    input, dim_index, start, end = node.args[:4]
    input_shape = input.meta["val"].shape
    i_dim = input_shape[dim_index]

    # A slice that spans the entire dimension is a no-op:
    # rewire its users to the input and drop the node.
    if start == 0 and self.dims_equivalent(end, i_dim):
        node.replace_all_uses_with(input)
        graph.erase_node(node)
        count += 1

elif is_func(node, torch.ops.aten.slice_scatter.default):
```
Are these always the right ops to use? e.g. is there a torch.ops.aten.slice.default or a torch.ops.aten.slice_scatter.Tensor?
I haven't seen them, so I'm not sure; I just went off what I saw. The other overloads could be added easily if we ever see them in the graph.
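For anyone wondering about the overloads discussed above, a quick way to check what actually exists (not part of this PR, just a hedged illustration) is to query the op packets directly. On the PyTorch versions I'm aware of, `slice` exposes a `Tensor` overload (among others) and `slice_scatter` a `default` overload, but the exact set may vary by release:

```python
import torch

# List the registered overloads for the two op packets discussed above.
# Expected (but version-dependent): "Tensor" for slice, "default" for slice_scatter.
print(torch.ops.aten.slice.overloads())
print(torch.ops.aten.slice_scatter.overloads())
```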
```
@@ -161,10 +162,14 @@ def apply_fp8_linear(
    # Note: we pad the input because torch._scaled_mm is more performant
    # for matrices with batch dimension > 16.
    # This could change in the future.
    # We also don't pad when using torch.compile,
    # as it breaks with dynamic shapes.
    config = get_current_vllm_config().compilation_config
```
Is this cached? It could be expensive on each forward call.
Yes, in eager mode this will get called on every forward pass, but it will only happen once when compiled. In eager mode there isn't really a better way that's still correct - the only way is to check the config context. I don't think this getter is significant but I haven't measured it.
We could pass in an `allow_input_padding` flag? I do think this is annoying though. I think it's worth it to do a quick check for performance regressions on a small-model eager-mode benchmark with `cutlass_scaled_mm` disabled?
I think we'd have to pass that flag through the whole call stack though so I don't think it's worth it. I'll run a small model.
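For readers skimming this thread, here is a minimal sketch of the gating being discussed, assuming a `compilation_config.level` field, a `pad_to` threshold of 17, and the illustrative helper name `maybe_pad_input` (none of these are asserted to match the actual vLLM code):

```python
import torch

from vllm.config import get_current_vllm_config  # getter referenced in the diff above


def maybe_pad_input(qinput: torch.Tensor, pad_to: int = 17) -> torch.Tensor:
    """Pad the batch dim for torch._scaled_mm only when running eagerly.

    Sketch only: the real decision also depends on cutlass availability;
    `config.level > 0` is an assumed stand-in for "torch.compile is enabled".
    """
    config = get_current_vllm_config().compilation_config
    if config.level > 0:
        # Under torch.compile, padding breaks dynamic shapes, so skip it.
        return qinput
    batch = qinput.shape[0]
    if batch >= pad_to:
        return qinput
    # Pad rows at the end of the batch dimension up to pad_to.
    return torch.nn.functional.pad(qinput, (0, 0, 0, pad_to - batch))
```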
Looks good overall but I had a few minor comments
- rename cutlass_fp8 test flag
- rename noop pass
- improve some comments

Signed-off-by: luka <[email protected]>
Thanks for the great work! LGTM assuming we don't see any performance regression
Yep, will post perf numbers once I have them, thanks!
This PR fixes the `fp8` case when `cutlass_mm` is not available. It contains the following fixes:

- Don't pad the input to the `fp8` `torch._scaled_mm` in the `torch.compile` case, as branch specialization might not work correctly, and it makes fusion difficult.
- Add `slice` and `slice_scatter` elimination (this is implemented in PyTorch but does not cover all cases) and rename the `RedundantReshapesPass` to `NoopEliminationPass`.

This PR is a prerequisite to #10836, which enables `torch.compile` on AMD and uses the non-cutlass-fp8 path.
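As a toy illustration of the new `slice` handling (a hypothetical, self-contained sketch, not the vLLM pass itself: the fixed size 4 and the literal `end == 4` check stand in for the metadata-based check in `NoopEliminationPass`):

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        # Slicing dim 0 from 0 to its full size (4 in this toy) is a no-op.
        y = torch.ops.aten.slice.Tensor(x, 0, 0, 4)
        return y + 1


gm = torch.fx.symbolic_trace(M())

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target == torch.ops.aten.slice.Tensor:
        inp, dim, start, end = node.args[:4]
        if start == 0 and end == 4:
            # The no-op slice contributes nothing: rewire its users to the
            # input and drop the node from the graph.
            node.replace_all_uses_with(inp)
            gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm.graph)  # the aten.slice.Tensor node is gone; the add uses x directly
```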