-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NPU]Custom fusion operator unification #8431
Conversation
Thanks for your contribution! |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## develop #8431 +/- ##
===========================================
- Coverage 55.43% 55.42% -0.01%
===========================================
Files 616 617 +1
Lines 96243 96281 +38
===========================================
+ Hits 53348 53366 +18
- Misses 42895 42915 +20 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
paddlenlp/transformers/fusion_ops.py
Outdated
flash_attention = None | ||
|
||
|
||
def fusion_rope(query_states, key_states, value_states, hidden_states, position_ids, past_key_value, rotary_emb): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fusion_rope、fusion_flash_attention
这种太长了就不建议去抽取了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经将paddlenlp/transformers/fusion_ops.py 移动到paddlenlp/transformers/llama/fusion_ops.py
@@ -81,14 +80,16 @@ def swiglu(x, y=None): | |||
|
|||
try: | |||
if get_env_device() == "npu": | |||
from paddle.base import core | |||
|
|||
for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")): | |||
if lib.endswith(".so"): | |||
paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注意看是不是有不需要的代码,注意删除掉。
PR types
Others
PR changes
Models
Description
Custom fusion operator unification