
[Feature Request] Support input embedding in LLM.generate() #416

Open
KimmiShi opened this issue Jul 10, 2023 · 17 comments · May be fixed by #6869 or #11684
Labels
feature request New feature or request

Comments

@KimmiShi

KimmiShi commented Jul 10, 2023

Hi, I am using an LLM as part of a multimodal model, so the model needs to pass an input-embedding tensor directly to generate. It also needs access to the language model's embed_tokens member to first compute the embeddings, then process them, and finally send them to generate, as in the following code:

    # Look up token embeddings from the language model's embedding table
    inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

    # Split the text embeddings at the offset and splice the vision features in between
    prefix_embeds = inputs_embeds[:, :self.offset, :]
    postfix_embeds = inputs_embeds[:, self.offset:, :]
    inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

    # ...
    attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

    outputs = self.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        **generate_kwargs,
    )

I read the vLLM code, and it seems I would need to add two interfaces to vLLM: one is LLM.get_input_embeddings, and the other is LLM.generate(inputs_embeds=inputs_embeds, ...).

Do you think this would work? And would you consider supporting this feature?
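For illustration, the splice step in the snippet above can be sketched shape-wise with NumPy standing in for torch (the sizes, offset, and `language_model_inputs` here are made-up values, not vLLM API):

```python
import numpy as np

batch, seq_len, hidden = 2, 8, 16
offset = 3          # hypothetical split point inside the text prompt
vision_len = 5      # hypothetical number of vision feature rows

# Stand-ins for the text token embeddings and the projected vision features
inputs_embeds = np.random.rand(batch, seq_len, hidden)
language_model_inputs = np.random.rand(batch, vision_len, hidden)

# Split the text embeddings at `offset` and splice the vision features in between
prefix_embeds = inputs_embeds[:, :offset, :]
postfix_embeds = inputs_embeds[:, offset:, :]
merged = np.concatenate([prefix_embeds, language_model_inputs, postfix_embeds], axis=1)

print(merged.shape)  # (2, 13, 16): seq_len + vision_len positions
```

The merged tensor is what would be handed to generate; the attention mask is concatenated along the sequence dimension in the same way.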

@KimmiShi KimmiShi changed the title Accept input embedding in generate(..) Accept input embedding in LLM.generate() Jul 10, 2023
@KimmiShi
Author

It seems the worker._prepare_inputs method needs to be modified to support embedding tensor input. Could you support this feature?

@KimmiShi KimmiShi changed the title Accept input embedding in LLM.generate() [Feature Request] Support input embedding in LLM.generate() Jul 11, 2023
@WoosukKwon WoosukKwon added the feature request New feature or request label Jul 13, 2023
@hangzhang-nlp

Awesome work!! And I have the same need.

@zacharyblank

Has there been any progress on this? I am looking to achieve something very similar: essentially, I need to be able to pass in a previously computed embedding so I don't have to recompute it as part of a common prompt. I have somewhere between 4k and 12k tokens that are currently reprocessed many times for a single request due to my use case.


@hmellor
Member

hmellor commented Mar 25, 2024

@WoosukKwon is this on the roadmap?

@Andcircle

What is the current status on this? =)
We also need this feature.

@hmellor
Member

hmellor commented Aug 2, 2024

According to #1265 (comment) this feature was added in #3042

@hmellor hmellor closed this as completed Aug 2, 2024
@Andcircle

@hmellor #3042 is not the feature we actually expected.

Is it possible to add a feature like the following?

```python
llm = LLM(model="mistral", ...)
# merge_inputs is a customized function provided by the user,
# which makes the process more flexible
inputs_embeds = merge_inputs(texts, images)
outputs = llm.generate(inputs_embeds=inputs_embeds, ...)
```

@hmellor hmellor reopened this Aug 3, 2024
@AnyangAngus

> @hmellor #3042 is not the feature we actually expected. Is it possible to add a feature like `llm.generate(inputs_embeds=merge_inputs(texts, images), ...)`?

Yeah, same request. Accepting input embeddings in the llm.generate() function would be straightforward.

@DarkLight1337
Member

DarkLight1337 commented Aug 22, 2024

FYI this is now supported for multi-modal models via #6613. Perhaps a similar idea could be used to extend this to language-only models.

@Andcircle

Andcircle commented Aug 22, 2024

> FYI this is now supported for multi-modal models via #6613. Perhaps a similar idea could be used to extend this to language-only models.

Thanks @DarkLight1337 for the updates.
I checked the demo code. We still provide the two modalities separately (prompt and images), and the merge process is still controlled only by the vLLM-supported VLM model. It is not flexible enough if we want our own merge method.

Could we do the following, so a customized VLM is treated as a PURE language model?

    # Start from a pure language model, NOT an existing VLM
    llm = LLM(model="mistral", ...)
    # merge_inputs is a customized function provided by the user,
    # which makes the process more flexible
    inputs_embeds = merge_inputs(texts, images)
    # The LLM only takes a batch of merged embeddings; it no longer cares whether
    # they came from image/video/audio inputs -- it acts as a pure language model
    outputs = llm.generate(inputs_embeds=inputs_embeds, ...)

I saw you also mentioned: "Follow-up TODO: Support initializing VLM with only language model backbone."
Is that the same as what I described above?

Again, I really appreciate your help.
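One way the proposed user-side `merge_inputs` could look, sketched with NumPy. Everything here is a hypothetical stand-in (the placeholder `IMAGE_TOKEN_ID`, the pre-computed text and image embeddings), not vLLM API: each placeholder token position is replaced by that image's feature rows, and the result is a single embedding sequence a pure language model could consume.

```python
import numpy as np

HIDDEN = 16
IMAGE_TOKEN_ID = -1  # hypothetical placeholder id marking where an image goes

def merge_inputs(token_ids: list[int],
                 token_embeds: np.ndarray,
                 image_embeds: list[np.ndarray]) -> np.ndarray:
    """Replace each IMAGE_TOKEN_ID position with that image's embedding rows."""
    pieces, img_iter = [], iter(image_embeds)
    for i, tok in enumerate(token_ids):
        if tok == IMAGE_TOKEN_ID:
            pieces.append(next(img_iter))          # splice in the image features
        else:
            pieces.append(token_embeds[i:i + 1])   # keep the text embedding row
    return np.concatenate(pieces, axis=0)

token_ids = [101, IMAGE_TOKEN_ID, 102, 103]
token_embeds = np.zeros((len(token_ids), HIDDEN))
image_embeds = [np.ones((4, HIDDEN))]              # one image -> 4 feature rows

merged = merge_inputs(token_ids, token_embeds, image_embeds)
print(merged.shape)  # (7, 16): 3 text rows + 4 image rows
```

The point of keeping this on the user side is that the LLM never needs to know which rows came from images; it only ever sees one merged embedding batch.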

@fzyzcjy
Contributor

fzyzcjy commented Nov 16, 2024

Hi, is there any updates? Thanks!

@DarkLight1337
Member

Please refer to that PR for more info.

@v4if

v4if commented Nov 21, 2024

any update?

@DaoD

DaoD commented Nov 25, 2024

same request

@lyblsgo

lyblsgo commented Dec 3, 2024

any update?

@Bryce1010 Bryce1010 linked a pull request Jan 2, 2025 that will close this issue
@sidhartha-roy

Is there an update on this?

groenenboomj pushed a commit to opendatahub-io/vllm that referenced this issue Feb 27, 2025