Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support pooling #229

Merged
merged 1 commit into from
Mar 4, 2025
Merged

[Core] Support pooling #229

merged 1 commit into from
Mar 4, 2025

Conversation

wangxiyuan
Copy link
Collaborator

@wangxiyuan wangxiyuan commented Mar 3, 2025

This PR added pooling support for vllm-ascend

Tested with bge-base-en-v1.5 by encode:

from vllm import LLM

# Sample prompts.
prompts = [
  "Hello, my name is",
  "The president of the United States is",
  "The capital of France is",
  "The future of AI is",
]
# Create an LLM.
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs = model.encode(prompts)
# Print the outputs.
for output in outputs:
    print(output.outputs.embedding)  # list of 4096 floats

Tested by embedding:

from vllm import LLM, SamplingParams

llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")

Related: #200 #235

Known issue

The accuracy is not correct since this feature rely on enc-dec support. It'll be done in the following PR by @MengqingCao

@MengqingCao
Copy link
Contributor

Test LLM.score with BAAI/bge-reranker-v2-m3 locally, and raised NotImplementedError:

[rank0]:   File "/home/xxx/code/vllm-cpu/vllm/vllm/attention/layer.py", line 220, in forward
[rank0]:     return self.impl.forward(self, query, key, value,
[rank0]:   File "/home/xxx/code/vllm-ascend/vllm_ascend/attention.py", line 546, in forward
[rank0]:     raise NotImplementedError("Encoder self-attention and "
[rank0]: NotImplementedError: Encoder self-attention and encoder/decoder cross-attention are not implemented for AscendAttentionBackendImpl

Seems we need to add the support of encoder self-attention and encoder/decoder cross-attention.
But I'm fine with this pr for initial pooling model support, we can do the rest of the work in the following PRs

@wangxiyuan
Copy link
Collaborator Author

Test LLM.score with BAAI/bge-reranker-v2-m3 locally, and raised NotImplementedError:

[rank0]:   File "/home/xxx/code/vllm-cpu/vllm/vllm/attention/layer.py", line 220, in forward
[rank0]:     return self.impl.forward(self, query, key, value,
[rank0]:   File "/home/xxx/code/vllm-ascend/vllm_ascend/attention.py", line 546, in forward
[rank0]:     raise NotImplementedError("Encoder self-attention and "
[rank0]: NotImplementedError: Encoder self-attention and encoder/decoder cross-attention are not implemented for AscendAttentionBackendImpl

Seems we need to add the support of encoder self-attention and encoder/decoder cross-attention. But I'm fine with this pr for initial pooling model support, we can do the rest of the work in the following PRs

BAAI/bge-reranker-v2-m3 rely on Encode-only attention which is not support yet. I think we can do it in enc-dec feature.

@wangxiyuan wangxiyuan force-pushed the pooling branch 3 times, most recently from e50e772 to e682021 Compare March 4, 2025 03:03
Signed-off-by: wangxiyuan <[email protected]>
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Mar 4, 2025
@wangxiyuan wangxiyuan merged commit ae49bfd into vllm-project:main Mar 4, 2025
11 checks passed
wangxiyuan added a commit to wangxiyuan/vllm-ascend that referenced this pull request Mar 4, 2025
This PR added pooling support for vllm-ascend

Tested with `bge-base-en-v1.5` by encode:
```
from vllm import LLM

prompts = [
  "Hello, my name is",
  "The president of the United States is",
  "The capital of France is",
  "The future of AI is",
]
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
outputs = model.encode(prompts)
for output in outputs:
    print(output.outputs.embedding)  # list of 4096 floats
```

Tested by embedding:
```
from vllm import LLM, SamplingParams

llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```

Related: vllm-project#200

The accuracy is not correct since this feature rely on `enc-dec`
support. It'll be done in the following PR by @MengqingCao

Signed-off-by: wangxiyuan <[email protected]>
wangxiyuan added a commit to wangxiyuan/vllm-ascend that referenced this pull request Mar 4, 2025
This PR added pooling support for vllm-ascend

Tested with `bge-base-en-v1.5` by encode:
```
from vllm import LLM

prompts = [
  "Hello, my name is",
  "The president of the United States is",
  "The capital of France is",
  "The future of AI is",
]
model = LLM(model="./bge-base-en-v1.5", enforce_eager=True)
outputs = model.encode(prompts)
for output in outputs:
    print(output.outputs.embedding)  # list of 4096 floats
```

Tested by embedding:
```
from vllm import LLM, SamplingParams

llm = LLM(model="./bge-base-en-v1.5", task="embed")
(output,) = llm.embed("Hello, my name is")

embeds = output.outputs.embedding
print(f"Embeddings: {embeds!r} (size={len(embeds)})")
```

Related: vllm-project#200

The accuracy is not correct since this feature rely on `enc-dec`
support. It'll be done in the following PR by @MengqingCao

Signed-off-by: wangxiyuan <[email protected]>
@wangxiyuan wangxiyuan deleted the pooling branch March 4, 2025 08:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation module:core
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants