[RFC]: Add support for IBM Spyre accelerator #9652

Open

tdoublep opened this issue Oct 24, 2024 · 5 comments
@tdoublep (Member) commented on Oct 24, 2024

Motivation.

IBM recently announced its Spyre AI accelerator at Hot Chips 2024. This accelerator has been designed, in collaboration with IBM Research, to scale up enterprise AI workloads running on IBM's mainframe systems (IBM Z), as well as on IBM's Power platform. Since we are building IBM's inference stack on top of vLLM, we would like to enable support for IBM Spyre within the vLLM framework.

Spyre has been designed to fit seamlessly into the PyTorch ecosystem via torch.compile. Specifically, IBM Research has developed a new backend for torch.compile that compiles torch FX graphs for execution on the Spyre hardware. In this sense, we envision that Spyre support in vLLM can work in a similar way to how TPU support works today (e.g., see here).
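For illustration, the standard torch.compile hook for such a backend is a callable that receives the traced FX graph and returns something to run in its place. The sketch below is a minimal stand-in: the `spyre_like_backend` name and its behaviour are placeholders, not the actual Spyre compiler.

```python
import torch

def spyre_like_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real accelerator backend would hand the FX graph to its compiler and
    # return a compiled callable; here we just report what was captured and
    # fall back to running the graph as-is.
    print(f"captured {len(gm.graph.nodes)} FX nodes")
    return gm.forward

@torch.compile(backend=spyre_like_backend)
def toy_model(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.T) + 1.0

toy_model(torch.randn(4, 4))
```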

Today, there are two key limitations that affect this integration and need to be worked around. Specifically:

  1. Spyre only supports execution of the modeling code from IBM's open-source Foundation Model Stack (fms). This is only a temporary limitation; the end goal is that, via torch.compile, we can also run the native vLLM modeling code on Spyre. We hope that recent efforts to make vLLM models torch.compile-compatible will significantly accelerate this effort.
  2. Spyre does not currently support paged attention or continuous batching. This means that, after prefilling a batch, we need to keep decoding until all of the sequences within that batch are finished. Our team is working to remove this limitation in the near future.

Proposed Change.

In this RFC, we propose the following sequence of PRs to enable IBM Spyre support in vLLM:

  • P1: Add support for single Spyre card via new SpyreExecutor class
  • P2: Changes to scheduling algorithm to disable continuous batching.
  • P3: Enable TP execution across multiple Spyre cards via MultiprocessingSpyreExecutor class.
  • P4: Enable paged attention and continuous batching for Spyre.
  • P5: Enable vLLM modeling code to run on Spyre.

While much of the work here (P1, P2, P3) has already been completed in a private fork, we plan to upstream the changes as a sequence of smaller PRs to make them easier to review. Below we discuss the planned changes in each PR.

P1: Add support for single Spyre card via new SpyreExecutor class

We will introduce a set of classes, inheriting from the core vLLM classes, that will enable execution on a single Spyre device. Architecturally, this will look very similar to the equivalent classes that were introduced for running on AWS Inferentia (e.g., NeuronExecutor, NeuronWorker, NeuronModelRunner, NeuronCausalLM). In the same way that the NeuronModelRunner uses the modeling code from the transformers_neuronx package, these new classes will execute the modeling code from IBM's fms package.

In the diagram below, we compare the proposed Spyre classes with the corresponding classes that already exist for the AWS Inferentia support:

[Diagram: proposed SpyreExecutor/SpyreWorker/SpyreModelRunner/SpyreCausalLM classes alongside the corresponding Neuron classes]
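As a rough sketch of the layering (class relationships only; the constructor and method signatures below are simplified assumptions, not the real vLLM interfaces):

```python
class SpyreExecutor:  # would inherit from vLLM's ExecutorBase, like NeuronExecutor
    """Owns a single worker and forwards execute_model calls to it."""

    def __init__(self, model_config, scheduler_config):
        self.worker = SpyreWorker(model_config, scheduler_config)
        self.worker.load_model()

    def execute_model(self, execute_model_req):
        return self.worker.execute_model(execute_model_req)


class SpyreWorker:  # analogue of NeuronWorker
    """Hosts the model runner on one Spyre device."""

    def __init__(self, model_config, scheduler_config):
        self.model_runner = SpyreModelRunner(model_config, scheduler_config)

    def load_model(self):
        self.model_runner.load_model()

    def execute_model(self, execute_model_req):
        return self.model_runner.execute_model(execute_model_req)


class SpyreModelRunner:  # analogue of NeuronModelRunner
    """Wraps the fms model, compiled via torch.compile with the Spyre backend."""

    def __init__(self, model_config, scheduler_config):
        self.model_config = model_config
        self.scheduler_config = scheduler_config
        self.model = None

    def load_model(self):
        # Would load the fms modeling code and torch.compile it for Spyre.
        ...

    def execute_model(self, execute_model_req):
        # Would pad inputs to a compiled shape, run the graph, sample tokens.
        ...
```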

Since Spyre works via torch.compile, to ensure that compilation does not occur on the critical path (i.e., while serving user requests), we need to ensure that all compilation is triggered at init time. This PR will therefore also introduce a routine for warming up the inference server when using Spyre, triggering compilation of all required shapes (e.g., prompt length, number of output tokens, batch size). We will write code to ensure that batches get padded to one of the compiled shapes before execution. This behaviour is akin to what happens today in vLLM for CUDA graphs, and presumably a warmup like this will also be needed once vLLM starts using torch.compile more extensively. This could be one area to explore commonality with other parts of the codebase.
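A minimal sketch of what such a warmup routine could look like, assuming a configurable shape list and a keyword-argument model interface (both are assumptions for illustration, not the actual integration code):

```python
import torch

# Hypothetical list of (prompt_len, max_new_tokens, batch_size) shapes to
# pre-compile; in practice this would be configurable.
WARMUP_SHAPES = [(64, 20, 4), (128, 20, 4), (2048, 128, 1)]

def warmup(compiled_model) -> None:
    """Run one dummy forward pass per shape at init time so that
    torch.compile traces and compiles every shape before serving."""
    for prompt_len, _max_new_tokens, batch_size in WARMUP_SHAPES:
        input_ids = torch.zeros(batch_size, prompt_len, dtype=torch.long)
        position_ids = torch.arange(prompt_len).expand(batch_size, -1)
        with torch.no_grad():
            # The first call for each static shape triggers compilation;
            # later requests padded to the same shape reuse the graph.
            compiled_model(input_ids=input_ids, position_ids=position_ids)
```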

Testing: While testing on the real hardware can only be performed internally for now, we can test the vast majority of the integration on CPU, either (a) by running in eager mode or (b) by using torch.compile with the inductor backend. Thus, in this PR we will also add a set of unit and integration tests to verify that everything behaves as expected. The tests will focus on offline mode, since changes to the scheduling algorithm are needed to support online mode (see P2). We will also add a Dockerfile.spyre containing all necessary dependencies (e.g., fms) in which the tests can be executed. Whether we could have these tests running as part of vLLM's CI/CD is something we would like to discuss.
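An offline-mode smoke test of the kind described above might look roughly like the following sketch; the model name, the environment variable, and the way the CPU backend is selected are all placeholders for illustration, not the planned test code.

```python
import pytest
from vllm import LLM, SamplingParams

@pytest.mark.parametrize("compile_backend", ["eager", "inductor"])
def test_offline_generation(compile_backend, monkeypatch):
    # Hypothetical switch for choosing how the model is executed on CPU;
    # the real integration may configure this differently.
    monkeypatch.setenv("SPYRE_TEST_COMPILE_BACKEND", compile_backend)
    llm = LLM(model="ibm-granite/granite-3.0-2b-instruct")
    params = SamplingParams(temperature=0.0, max_tokens=20)
    outputs = llm.generate(["Hello, my name is"], params)
    assert outputs and outputs[0].outputs[0].text
```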

P2: Changes to scheduling algorithm to disable continuous batching.

We need to introduce a few changes to the scheduling algorithm to work around the lack of continuous batching support. Specifically:

  1. We must not schedule another prefill until all decodes in the running batch are finished (one line change).
  2. We need to introduce some logic to decide how to batch requests together, based on prompt lengths and max output tokens, in order to best fit the shapes that were compiled on Spyre during the warmup phase.

These changes must be conditional and must not affect the behaviour of the scheduler on existing supported devices. They could either be applied within the scheduler itself (e.g., by checking is_spyre()), or we could try to "plug in" an alternate scheduler design. This is one of the design choices we would like feedback on.
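For concreteness, a hedged sketch of the two changes above; the helper names, argument shapes, and platform check are assumptions, not the actual scheduler code.

```python
def can_schedule_prefill(num_running: int, is_spyre: bool) -> bool:
    """Change (1): on Spyre, defer new prefills until the running batch has
    fully drained, since the batch cannot be changed mid-decode."""
    if not is_spyre:
        return True          # existing continuous-batching behaviour
    return num_running == 0

def select_prefill_batch(waiting_prompt_lens, compiled_shapes):
    """Change (2): greedily group waiting requests so that their padded
    prompt length and count fit one of the (prompt_len, batch_size) shapes
    compiled during warmup."""
    for prompt_len, batch_size in sorted(compiled_shapes):
        fits = [p for p in waiting_prompt_lens if p <= prompt_len]
        if fits:
            return fits[:batch_size], (prompt_len, batch_size)
    return [], None

# Example: three waiting prompts, shapes compiled for (64, 4) and (2048, 1).
print(select_prefill_batch([30, 50, 1500], [(64, 4), (2048, 1)]))
```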

Testing: As part of this PR, we will also introduce tests to cover the integration with the MQLLMEngine and online operation.

P3: Enable TP execution across multiple Spyre cards via MultiprocessingSpyreExecutor class.

We have found that the MultiprocessingGPUExecutor can be easily adapted into a MultiprocessingSpyreExecutor to enable TP execution across multiple Spyre devices in parallel. However, to reduce code duplication, we propose refactoring the code shared between MultiprocessingGPUExecutor and MultiprocessingSpyreExecutor into a common parent class, MultiprocessingExecutor. By inheriting from MultiprocessingExecutor and a corresponding mixin class (e.g., GPUExecutor or SpyreExecutor), it should be possible to achieve the desired behaviour with very little device-specific code. Note that something along these lines already exists for the MultiprocessingXPUExecutor (e.g., see here), but the design proposed below would give more flexibility for device-specific specialization, and would also make it easy to create multiprocessing executors for all supported devices if we want.

The architecture would look something like this:

[Diagram: MultiprocessingExecutor parent class combined with device-specific mixins to produce MultiprocessingGPUExecutor and MultiprocessingSpyreExecutor]
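In code, the split between shared and device-specific logic might look roughly like the sketch below; apart from the existing vLLM executor names mentioned above, every name here is an assumption, and in the real refactor the workers would live in separate processes rather than in-process objects.

```python
class SpyreWorker:
    """Stand-in for the per-card worker from P1."""
    def __init__(self, rank: int):
        self.rank = rank

    def execute_model(self, execute_model_req):
        ...  # run this card's shard of the compiled fms model


class MultiprocessingExecutor:
    """Device-agnostic logic: create one worker per TP rank and fan out calls."""
    def __init__(self, tensor_parallel_size: int):
        self.workers = [self._create_worker(rank)
                        for rank in range(tensor_parallel_size)]

    def _create_worker(self, rank: int):
        raise NotImplementedError  # supplied by a device-specific mixin

    def execute_model(self, execute_model_req):
        outputs = [w.execute_model(execute_model_req) for w in self.workers]
        return outputs[0]  # the driver worker's sampled tokens


class SpyreExecutorMixin:
    """Device-specific piece: only knows how to build a Spyre worker."""
    def _create_worker(self, rank: int):
        return SpyreWorker(rank=rank)


class MultiprocessingSpyreExecutor(SpyreExecutorMixin, MultiprocessingExecutor):
    """TP across multiple Spyre cards with almost no device-specific code."""


MultiprocessingSpyreExecutor(tensor_parallel_size=2)
```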

Testing: We will add tests to verify that the MultiprocessingSpyreExecutor behaves as expected for tensor-parallel execution when running on CPU using either eager mode or the inductor backend. Internally, we will run these tests against the real hardware.

P4: Enable paged attention and continuous batching for Spyre.

TBD

P5: Enable vLLM modeling code to run on Spyre.

TBD

Feedback Period.

2 weeks

CC List.

@njhill @simon-mo @youkaichao @zhuohan123 @comaniac @WoosukKwon @Yard1

Please cc anyone else as you see fit!

Any Other Things.

No response

tdoublep added the RFC label on Oct 24, 2024
@tlrmchlsmth (Collaborator) commented:

Hi @tdoublep, thanks for the thorough writeup! I had some torch.compile-related questions.

> Specifically, IBM Research has developed a new backend for torch.compile that will compile torch FX graphs for execution on the Spyre hardware

Is this a torch dynamo backend? Or does it integrate via Inductor?

> This PR will also introduce a routine for warming up the inference server when using the Spyre, triggering compilation of all required shapes (e.g., prompt length, number of output tokens, batch size). We will write code to ensure that batches get padded to one of the compiled shapes before execution. This behaviour is akin to what happens today in vLLM for CUDA graphs, and presumably something like this warmup will also be needed once vLLM starts using torch.compile more extensively.

Agreed, something like this would be useful for GPUs. We want to be able to write custom inductor passes where we only apply certain optimizations depending on the number of tokens we're working with. (cc @bnellnm who is working on an inductor pass like that)

BTW: It sounds like the Spyre might require all shapes to be static, is that true?

@tdoublep (Member, Author) commented:

Thanks for reading, @tlrmchlsmth.

> Is this a torch dynamo backend? Or does it integrate via Inductor?

The former, it integrates at the level of FX graphs.

> BTW: It sounds like the Spyre might require all shapes to be static, is that true?

Right now each compiled graph has a static prompt length and batch size, but the number of output tokens is dynamic to an extent (we still need to define some maximum). We are able to compile multiple graphs to support different cases, though, and we have some logic to pad to the nearest reasonable shape (again, somewhat like how CUDA graphs work today, iirc).
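To make the padding idea concrete, rounding a request up to the nearest compiled shape can be as simple as the toy snippet below; the bucket values are made up for illustration.

```python
import bisect

COMPILED_BATCH_SIZES = [1, 2, 4, 8]          # shapes captured at warmup
COMPILED_PROMPT_LENS = [64, 128, 256, 2048]

def nearest_shape(prompt_len: int, batch_size: int):
    """Round both dimensions up to the next compiled value, or return None
    if the request exceeds everything that was compiled."""
    b = bisect.bisect_left(COMPILED_BATCH_SIZES, batch_size)
    p = bisect.bisect_left(COMPILED_PROMPT_LENS, prompt_len)
    if b == len(COMPILED_BATCH_SIZES) or p == len(COMPILED_PROMPT_LENS):
        return None
    return COMPILED_PROMPT_LENS[p], COMPILED_BATCH_SIZES[b]

print(nearest_shape(prompt_len=100, batch_size=3))   # -> (128, 4)
```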

@tdoublep (Member, Author) commented:
Update from our side: we have now open-sourced the code for Spyre support on IBM's fork of vLLM:
https://github.com/IBM/vllm

@tlrmchlsmth (Collaborator) commented:
Nice, looks like all the needed changes are contained in this PR IBM/vllm#56?

@tdoublep (Member, Author) commented:
@tlrmchlsmth Yes, those are the changes compared with a commit from last month. We will continue developing on the open repo from here onwards, and will be pulling in changes from upstream frequently.
