Fix synchronized memcpy in GPT #7008
Conversation
Thanks for your contribution!
if loss_mask is None:
    loss_mask = (masked_lm_loss > 0).astype("float32")
loss_mask = loss_mask.reshape([-1])
masked_lm_loss = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask)
Support for using custom loss mask?
The change here is mainly because the `masked_lm_loss[masked_lm_loss > 0]` pattern triggers a D2H copy. Rewriting it as a multiplication of loss_mask with the LM loss to obtain masked_lm_loss is equivalent, but avoids the D2H copy.
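For illustration, a minimal standalone sketch of the two patterns (not taken from the PR; the random input and the final mean / normalized-sum reduction are assumptions made for the example):

```python
import paddle

masked_lm_loss = paddle.rand([8, 128])  # stand-in for the per-token LM loss

# Original pattern: boolean indexing. The output shape depends on how many
# elements are > 0, which forces a device-to-host sync before the reduction.
loss_old = paddle.mean(masked_lm_loss[masked_lm_loss > 0].astype("float32"))

# PR-style pattern: multiply by a float mask and reduce. All shapes are fixed,
# so the whole computation can stay on the device.
loss_mask = (masked_lm_loss > 0).astype("float32").reshape([-1])
loss_new = paddle.sum(masked_lm_loss.reshape([-1]) * loss_mask) / loss_mask.sum()

# With loss_mask defined as (masked_lm_loss > 0), the two results match.
print(float(loss_old), float(loss_new))
```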
Is this caused by the slice op?
One thing to check here: the dtype of the final masked_lm_loss should be float32.
In the original implementation, `masked_lm_loss = masked_lm_loss[masked_lm_loss > 0].astype("float32")`, the shape of the returned masked_lm_loss depends on how many elements of `masked_lm_loss > 0` are True. That count therefore has to be sent back to the CPU, which requires a DtoH copy; the `masked_lm_loss[masked_lm_loss > 0]` form cannot avoid it.
The change in this PR sidesteps the getitem operation, computes the same result, and avoids the DtoH copy.
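The data-dependent output shape can be seen directly in a toy example (illustrative only, not code from the PR):

```python
import paddle

x = paddle.to_tensor([0.5, 0.0, 1.2, 0.0])

# Boolean indexing: the output shape depends on the data (2 positives here),
# so the count of True values has to be known on the host side first.
print(x[x > 0].shape)    # [2]

# Mask-multiply: the output shape is fixed regardless of the data,
# so the reduction never needs a DtoH copy of the count.
mask = (x > 0).astype("float32")
print((x * mask).shape)  # [4]
```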
> One thing to check here: the dtype of the final masked_lm_loss should be float32.

The current form should already guarantee float32, right?
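A quick dtype check along the lines of the new code (tensor names are made up for the example):

```python
import paddle

loss = paddle.rand([4])                     # float32 by default
mask = (loss > 0).astype("float32")         # bool mask cast to float32
out = paddle.sum(loss.reshape([-1]) * mask)
print(out.dtype)                            # paddle.float32
```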
LGTM
Codecov Report
@@ Coverage Diff @@
## develop #7008 +/- ##
===========================================
- Coverage 60.06% 59.90% -0.17%
===========================================
Files 552 554 +2
Lines 81755 81976 +221
===========================================
Hits 49105 49105
- Misses 32650 32871 +221
PR types
Performance optimization
PR changes
Others
Description
Avoid synchronized memcpy in GPT pretraining