
[Feature] Fused Mixtral support #8901

Merged: 10 commits merged into PaddlePaddle:develop on Aug 26, 2024
Conversation

penPenf28 (Contributor) commented on Aug 8, 2024

PR types

New features

PR changes

Models

Description

Adds support for a high-performance (fused) version of the Mixtral-8x7B-Instruct-v0.1 model. Currently bfloat16 + wint8 is supported, covering both the non-block and block variants of the model.

The code still contains some redundant quantization pieces; these will be revised in a follow-up that adds the corresponding quantization support.
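For readers unfamiliar with the terminology, below is a minimal, self-contained sketch of what weight-only INT8 ("wint8") means for a single linear layer: weights are stored as int8 with one scale per output channel and dequantized at matmul time, while activations stay in a floating-point format (bfloat16 in the fused path). The shapes and the eager dequantization are illustrative only; the PR itself relies on Paddle's fused kernels rather than this naive loop-free version.

```python
import paddle

# Illustrative shapes only (Mixtral's real projections are much larger).
x = paddle.randn([8, 1024])        # activations; bfloat16 in the fused path
w = paddle.randn([1024, 4096])     # full-precision weight of one linear layer

# Quantize: one scale per output channel, weights stored as int8.
scale = w.abs().max(axis=0) / 127.0
w_int8 = paddle.clip(paddle.round(w / scale), -127, 127).astype("int8")

# Dequantize on the fly and run the matmul. A real wint8 kernel performs the
# dequantization inside the GEMM instead of materializing the full weight.
w_deq = w_int8.astype("float32") * scale
y = paddle.matmul(x, w_deq)
print(y.shape)  # [8, 4096]
```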

paddle-bot (bot) commented on Aug 8, 2024

Thanks for your contribution!

codecov (bot) commented on Aug 8, 2024

Codecov Report

Attention: Patch coverage is 0% with 618 lines in your changes missing coverage. Please review.

Project coverage is 54.05%. Comparing base (d505a97) to head (7ed1917).
Report is 240 commits behind head on develop.

Files with missing lines | Patch % | Missing lines
...enlp/experimental/transformers/mixtral/modeling.py | 0.00% | 547 ⚠️
...erimental/transformers/fused_transformer_layers.py | 0.00% | 69 ⚠️
paddlenlp/experimental/transformers/__init__.py | 0.00% | 1 ⚠️
...enlp/experimental/transformers/mixtral/__init__.py | 0.00% | 1 ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8901      +/-   ##
===========================================
- Coverage    54.80%   54.05%   -0.75%     
===========================================
  Files          647      650       +3     
  Lines       102474   104427    +1953     
===========================================
+ Hits         56157    56445     +288     
- Misses       46317    47982    +1665     

☔ View full report in Codecov by Sentry.

penPenf28 marked this pull request as draft on August 8, 2024 12:01
penPenf28 marked this pull request as ready for review on August 14, 2024 08:22
penPenf28 force-pushed the fused_mixtral branch 3 times, most recently from 6070538 to 2b8afcf on August 19, 2024 03:06
return self.num_experts > 1

def use_moe(self, i: int) -> bool:
return self.has_moe() and (self.moe_every2 is False or (self.moe_every2 and i % 2 == 1))
Contributor commented:

This check is a bit odd. What if I want to swap in an MoE layer only every four layers?

Contributor commented:

That said, since this only targets Mixtral for now, let's leave it as is for the time being.

Contributor Author (penPenf28) commented:

OK. If there is a need for every 4, 8, ... layers, I think the moe_every parameter could be turned into an enum and the check driven by that enum. I'm currently working on other support, so a follow-up PR can make this change.
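A minimal sketch of the enum idea floated above, assuming a hypothetical MoeLayerPattern enum (the names are illustrative and not part of this PR); with EVERY_2 it reproduces the current `i % 2 == 1` behavior:

```python
from enum import Enum


class MoeLayerPattern(Enum):
    """Hypothetical replacement for the boolean moe_every2 flag."""

    ALL = 1      # every decoder layer is an MoE layer
    EVERY_2 = 2  # every second layer is an MoE layer
    EVERY_4 = 4  # every fourth layer is an MoE layer


def use_moe(num_experts: int, pattern: MoeLayerPattern, i: int) -> bool:
    # Use an MoE layer only when experts exist and layer i matches the pattern.
    has_moe = num_experts > 1
    return has_moe and i % pattern.value == pattern.value - 1


print([use_moe(8, MoeLayerPattern.EVERY_2, i) for i in range(4)])  # [False, True, False, True]
print([use_moe(8, MoeLayerPattern.ALL, i) for i in range(4)])      # [True, True, True, True]
```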

DesmonDay (Contributor) commented:

Suggest adding related unit tests in a follow-up to make sure the functionality is correct. @penPenf28 @yuanlehome
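To make the suggestion concrete, here is a rough sketch of what such a test could look like for the MoE-layer selection logic alone. The logic is copied inline so the sketch stays standalone; a real test would exercise the actual config class in fused_transformer_layers.py instead.

```python
import unittest


def use_moe(num_experts: int, moe_every2: bool, i: int) -> bool:
    # Inlined copy of the selection logic from this PR, so the sketch is self-contained.
    has_moe = num_experts > 1
    return has_moe and (moe_every2 is False or (moe_every2 and i % 2 == 1))


class UseMoeLogicTest(unittest.TestCase):
    def test_all_layers_are_moe_when_every2_disabled(self):
        self.assertTrue(all(use_moe(8, False, i) for i in range(8)))

    def test_only_odd_layers_are_moe_when_every2_enabled(self):
        self.assertEqual([use_moe(8, True, i) for i in range(4)], [False, True, False, True])

    def test_dense_model_never_uses_moe(self):
        self.assertFalse(any(use_moe(1, False, i) for i in range(8)))


if __name__ == "__main__":
    unittest.main()
```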

@@ -1128,6 +1154,29 @@ def compute_out_linear(self, fmha_out, i):
weight_dtype=self.weight_dtype,
)

def compute_fused_moe(self, tmp_out, i):
# todo[xinhw]: make bias optional
Collaborator commented:

This bug needs to be fixed as soon as possible.
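As an aside, the "make bias optional" TODO essentially amounts to the pattern below: add a bias only when the checkpoint provides one, rather than always requiring (or zero-filling) a bias tensor. This is a generic sketch, not the fused MoE kernel's actual signature.

```python
import paddle


def linear_with_optional_bias(x, weight, bias=None):
    # Add the bias only when one is actually provided; checkpoints without
    # expert biases then simply pass bias=None.
    out = paddle.matmul(x, weight)
    if bias is not None:
        out = out + bias
    return out


x = paddle.randn([2, 1024])
w = paddle.randn([1024, 4096])
print(linear_with_optional_bias(x, w).shape)                        # no bias -> [2, 4096]
print(linear_with_optional_bias(x, w, paddle.zeros([4096])).shape)  # with bias -> [2, 4096]
```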

@@ -713,6 +794,29 @@ def compute_ffn_layernorm(self, out_linear_out, residual_input, i):

return tmp_out, residual_input

def compute_fused_moe(self, tmp_out, i):
# todo[xinhw]: make bias optional
Collaborator commented:

This bug needs to be fixed as soon as possible.

DesmonDay (Contributor) left a review:
LGTM

wawltor merged commit 31cc283 into PaddlePaddle:develop on Aug 26, 2024
9 of 12 checks passed
Mangodadada pushed a commit to Mangodadada/PaddleNLP that referenced this pull request Sep 10, 2024
* [Feature] Fused Mixtral support

* [Refactor] add MoeConfig and fix static graph export problem

* [Bugfix] fix small bug

* [Bugfix] fix moe_config bug

* [Bugfix] fix moe_config bug

* [Refactor] refine code

* [Refactor] refine code

* [Refactor] refine code

* [Refactor] match fused moe api change

* [Feature] wint8 support