【prim】update Prim optest about customvjp test, add FLAGS_comp_skip_default_ops #71269
+40
−15
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PR Category
Execute Infrastructure
PR Types
Improvements
Description
pcard-67164
增加前向拆解默认跳过的list 包含 embedding, dropout 算子,原因是编译器不好处理embedding 拆解得到的gather 等算子, dropout 拆解的uniform和原始dropout的随机数生成无法对齐。
添加跳过前向拆解默认算子的flag, FLAGS_comp_skip_default_ops=True, 默认为true 表示默认不拆解上述算子,如需取消可以设置为false 与反向拆解FLAGS_prim_vjp_skip_default_ops 对应
修改prim_op_test 反向测试的逻辑,如果是custom_vjp 则反向需要跳过相应的前向算子拆解。 之前测试的是前向拆解结果自动微分得到的结果,修改为使用反向拆解的结果。