【Hackathon 8th No.1】add lu_solve api for paddle -part #71030

Merged
merged 36 commits into from
Mar 10, 2025

Contributor

@decade-afk decade-afk commented Feb 6, 2025

PR Category

User Experience

PR Types

New features

Description

【Hackathon 8th No.1】add lu_solve api for paddle
docs: PaddlePaddle/docs#7052
rfc: PaddlePaddle/community#962


paddle-bot bot commented Feb 6, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@decade-afk
Contributor Author

Does the build for DCU chips not support the cuSOLVER library?
[image]
@jeff41404 @luotao1

pivots (Tensor): Permutation matrix P of LU decomposition. It has shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted to a permutation matrix P, with data type int32.
trans (str): The transpose of the matrix A. It can be "N" , "T" or "C", default is "N".
Contributor

The meaning of "N", "T" or "C" needs to be clearly described, and corresponding sample code should be added to help users better understand.

Contributor Author

Done
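
To illustrate what the reviewer asked for, here is a minimal pure-Python sketch of how `trans` changes the system being solved. It is not Paddle's implementation: `lu_factor` and `lu_solve_ref` are hypothetical helper names, and the 1-based sequential row-swap pivots are an assumed LAPACK-style convention. For real inputs the sketch treats "C" the same as "T".

```python
from typing import List, Tuple

Matrix = List[List[float]]


def lu_factor(a: Matrix) -> Tuple[Matrix, List[int]]:
    """LU-factor a square matrix with partial pivoting (illustrative only).

    Returns the packed LU matrix (unit-diagonal L below the diagonal, U on
    and above it) and 1-based sequential row-swap pivots.
    """
    n = len(a)
    lu = [row[:] for row in a]
    pivots = []
    for k in range(n):
        # Pick the row with the largest magnitude in column k.
        p = max(range(k, n), key=lambda i: abs(lu[i][k]))
        lu[k], lu[p] = lu[p], lu[k]
        pivots.append(p + 1)
        for i in range(k + 1, n):
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                lu[i][j] -= lu[i][k] * lu[k][j]
    return lu, pivots


def lu_solve_ref(b: List[float], lu: Matrix, pivots: List[int], trans: str = "N") -> List[float]:
    """Solve Ax=b ("N") or A^T x=b ("T"; "C" is identical for real data)."""
    n = len(lu)
    x = b[:]
    if trans == "N":
        # P A = L U, so solve L U x = P b: swap, forward-solve, back-solve.
        for k, p in enumerate(pivots):
            x[k], x[p - 1] = x[p - 1], x[k]
        for i in range(n):
            x[i] -= sum(lu[i][j] * x[j] for j in range(i))
        for i in reversed(range(n)):
            x[i] = (x[i] - sum(lu[i][j] * x[j] for j in range(i + 1, n))) / lu[i][i]
    elif trans in ("T", "C"):
        # A^T = U^T L^T P, so solve U^T, then L^T, then undo the swaps.
        for i in range(n):
            x[i] = (x[i] - sum(lu[j][i] * x[j] for j in range(i))) / lu[i][i]
        for i in reversed(range(n)):
            x[i] -= sum(lu[j][i] * x[j] for j in range(i + 1, n))
        for k in reversed(range(n)):
            p = pivots[k] - 1
            x[k], x[p] = x[p], x[k]
    else:
        raise ValueError(f"trans must be 'N', 'T' or 'C', got {trans!r}")
    return x


A = [[2.0, 1.0, 1.0], [4.0, 3.0, 3.0], [8.0, 7.0, 9.0]]
b = [1.0, 2.0, 3.0]
lu, piv = lu_factor(A)
x_n = lu_solve_ref(b, lu, piv, "N")  # satisfies A   x_n = b
x_t = lu_solve_ref(b, lu, piv, "T")  # satisfies A^T x_t = b
```

The same factorization serves both modes; only the order of the triangular solves and the direction of the row swaps change.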

Comment on lines 3579 to 3583
def lu_solve(
    b: Tensor,
    lu: Tensor,
    pivots: Tensor,
    trans: str = "N",
Contributor

Shall we support a `left` argument so that users can solve xA = b?

Contributor Author

@decade-afk decade-afk Feb 28, 2025

No. To solve xA = b, rewrite it as A^T x^T = b^T and then combine it with paddle.lu.
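
The rewrite rests on the identity (xA)^T = A^T x^T. Spelled out as a short derivation:

```latex
xA = b
\;\Longleftrightarrow\; (xA)^{T} = b^{T}
\;\Longleftrightarrow\; A^{T} x^{T} = b^{T}
```

Under that reading, one plausible route (an assumption, not something confirmed in this thread) is to factor A once, solve for x^T with `trans="T"` and right-hand side b^T, and transpose the result.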

given LU decomposition :math:`A` and column vector :math:`b`.
Args:
b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions, with data type float32, float16.
Contributor

Is float64 also supported?
And do we support complex64/complex128?

Contributor Author

The supported precision is aligned with the paddle.lu function: float32 and float64 are supported.

self.A_shape = [10, 10]
self.b_shape = [10, 5]
self.trans = "T"
self.dtype = "float64"
Contributor

Do we support type of complex64/complex128? If so, we need to add a single test.

Contributor Author

complex64/complex128 are not supported

@decade-afk
Contributor Author

The CI tests have been fully completed, please review. @jeff41404

jeff41404 previously approved these changes Mar 3, 2025
Contributor

@jeff41404 jeff41404 left a comment

LGTM

@decade-afk
Contributor Author

[image]
Is it normal for this check to stay in the waiting state? @sunzhongkai588

@luotao1
Contributor

luotao1 commented Mar 4, 2025

This is a newly added CI check; please wait for Sun's review first.

Comment on lines 3591 to 3597
b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions, with data type float32, float64.
lu (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch dimensions, that can be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type float32, float64.
pivots (Tensor): Permutation matrix P of LU decomposition. It has shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted to a permutation matrix P, with data type int32.
trans (str): The transpose of the matrix A. It can be "N" , "T" or "C", "N" means :math:`Ax=b`, "T" means :math:`A^Tx=b`, "C" means :math:`A^Hx=b`, default is "N".
Contributor

Suggested change
- b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions, with data type float32, float64.
- lu (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch dimensions, that can be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type float32, float64.
- pivots (Tensor): Permutation matrix P of LU decomposition. It has shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted to a permutation matrix P, with data type int32.
- trans (str): The transpose of the matrix A. It can be "N" , "T" or "C", "N" means :math:`Ax=b`, "T" means :math:`A^Tx=b`, "C" means :math:`A^Hx=b`, default is "N".
+ b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions, with data type float32, float64.
+ lu (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch dimensions, that can be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type float32, float64.
+ pivots (Tensor): Permutation matrix P of LU decomposition. It has shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted to a permutation matrix P, with data type int32.
+ trans (str, optional): The transpose of the matrix A. It can be "N" , "T" or "C", "N" means :math:`Ax=b`, "T" means :math:`A^Tx=b`, "C" means :math:`A^Hx=b`, default is "N".

Also, please add a description for the `name` parameter.

Contributor

@decade-afk Please note: do not leave blank lines between parameters, and optional parameters must be marked as optional. See https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/api_docs_guidelines_cn.html

Contributor Author

> @decade-afk Please note: do not leave blank lines between parameters, and optional parameters must be marked as optional. See https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/dev_guides/api_contributing_guides/api_docs_guidelines_cn.html

Done

name=None,
):
r"""
Computes the solution y to the system of linear equations :math:`Ax = b` ,
Contributor

What is y? It does not seem to be mentioned anywhere below.

Contributor Author

> What is y? It does not seem to be mentioned anywhere below.

Everything has been revised as requested.

sunzhongkai588 previously approved these changes Mar 7, 2025
Contributor

@sunzhongkai588 sunzhongkai588 left a comment

LGTM

Comment on lines 3580 to 3584
b: Tensor,
lu: Tensor,
pivots: Tensor,
trans: str = "N",
name=None,
Contributor

There are still problems in this part. @SigureMo, could you please review it?

How about merging this PR first? It is too large.


def test_static(self):
    def run(place):
        paddle.set_flags({'FLAGS_enable_pir_api': False})
Member

Why does the static-graph test only cover the old IR? For a new API targeting PIR, you may skip the old-IR branch, but you cannot leave PIR completely untested.

Member

The API contribution guide is somewhat outdated; I will find time to update it later.

Contributor Author

> Why does the static-graph test only cover the old IR? For a new API targeting PIR, you may skip the old-IR branch, but you cannot leave PIR completely untested.

I wrote this based on the test code for lu, and I believe the new PIR is being tested. Judging from my other PR, the CI environment currently sets FLAGS_enable_pir_api=1.

Member

> I wrote this based on the test code for lu, and I believe the new PIR is being tested.

How did you determine that? The code over there does not disable PIR.

Contributor Author

> I wrote this based on the test code for lu, and I believe the new PIR is being tested.
>
> How did you determine that? The code over there does not disable PIR.

Testing PIR requires setting export FLAGS_enable_pir_api=1, right? I did that earlier when testing the static graph and only later changed it to export FLAGS_enable_pir_api=0. Also, the script there does not turn off the new PIR.

Member

> I did that earlier when testing the static graph

Not "earlier". What we need is ongoing monitoring; having tested it once proves nothing. Can the code here monitor PIR? If PIR later runs with no coverage at all, half of this API will have been built for nothing.

> Testing PIR requires setting export FLAGS_enable_pir_api=1, right?

PIR is the default; no flag is needed.

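Contributor Author

> I did that earlier when testing the static graph
>
> Not "earlier". What we need is ongoing monitoring; having tested it once proves nothing. Can the code here monitor PIR? If PIR later runs with no coverage at all, half of this API will have been built for nothing.

[image]
So should I write another PIR test following this example? But it was already running in PIR mode; the reason my static-graph code was never exercised is that execution jumped into PIR mode.

> Testing PIR requires setting export FLAGS_enable_pir_api=1, right?
>
> PIR is the default; no flag is needed.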

@@ -63,4 +63,5 @@
'smooth_l1_loss',
'spectral_norm',
'complex',
'lu_solve',
Member

NEED_TO_FIX_OP_LIST looks like a list of ops that need fixing. Why is a newly added API being added to it? Is there a problem? This OP appears to have a backward pass.

Contributor Author

> NEED_TO_FIX_OP_LIST looks like a list of ops that need fixing. Why is a newly added API being added to it? Is there a problem? This OP appears to have a backward pass.

Because the backward pass has one operator, pivots, that does not need a gradient.

Member

Isn't pivots an input?

Contributor Author

> Isn't pivots an input?

But it holds the pivot information: the positions of the row exchanges applied to the matrix.
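
As a small illustration of "row-exchange position information", the sketch below turns a pivots vector into an explicit row order. It assumes 1-based, sequentially applied swaps (the LAPACK-style convention); `pivots_to_permutation` is a hypothetical helper, not a Paddle API. Because the values are integer indices, there is nothing to differentiate, which is consistent with the author's point that pivots needs no gradient.

```python
from typing import List


def pivots_to_permutation(pivots: List[int]) -> List[int]:
    """Convert 1-based sequential row swaps into a 0-based row order.

    perm[i] is the index of the original row that ends up in row i
    after every swap has been applied in order.
    """
    perm = list(range(len(pivots)))
    for k, p in enumerate(pivots):
        # Step k exchanges row k with row p-1 (p is 1-based).
        perm[k], perm[p - 1] = perm[p - 1], perm[k]
    return perm


row_order = pivots_to_permutation([3, 3, 3])
```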

Member

My question is: if the API is added here, is the backward pass actually tested? The code clearly implements backward, but is there a test for it?

I was only saying that pivots should be an input rather than an OP?

If you have confirmed that it is tested, then there is no problem.

Contributor Author

@decade-afk decade-afk Mar 7, 2025

> My question is: if the API is added here, is the backward pass actually tested? The code clearly implements backward, but is there a test for it?
>
> I was only saying that pivots should be an input rather than an OP?
>
> If you have confirmed that it is tested, then there is no problem.

Yes, it is tested. I wrote the test in another PR; it just cannot be detected when written in impl.
#71244
https://xly.bce.baidu.com/paddlepaddle/paddle/newipipe/detail/12424272/job/28789277
[image]

Member

Then I have no further concerns here.

Comment on lines 3661 to 3674
    else:
        check_variable_and_dtype(b, 'dtype', ['float32', 'float64'], 'lu_solve')
        check_variable_and_dtype(
            lu, 'dtype', ['float32', 'float64'], 'lu_solve'
        )
        check_variable_and_dtype(pivots, 'dtype', ['int32'], 'lu_solve')
        helper = LayerHelper('lu_solve', **locals())
        out = helper.create_variable_for_type_inference(dtype=b.dtype)
        helper.append_op(
            type='lu_solve',
            inputs={'B': b, 'Lu': lu, 'Pivots': pivots},
            outputs={'Out': out},
            attrs={'trans': trans},
        )
Member

As mentioned above, a new API does not need this branch, but you cannot drop PIR static-graph testing just to pass the coverage check.

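Contributor Author

> As mentioned above, a new API does not need this branch, but you cannot drop PIR static-graph testing just to pass the coverage check.

So does that mean the static-graph code does not need to be written, and the new PIR converts dynamic to static automatically? Also, I just removed the error in my other PR and reran it, so that it can check whether my static-graph code has any errors.
#71285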

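Member

> So does that mean the static-graph code does not need to be written, and the new PIR converts dynamic to static automatically?

It is not that the code goes unwritten. The auto-generated code in _C_ops.lu_solve performs the dynamic/static dispatch, so you simply do not need to write it by hand. Dynamic graph still follows the dynamic-graph logic and static graph still follows the static-graph logic; nothing has fundamentally changed.

> Also, I just removed the error in my other PR and reran it, so that it can check whether my static-graph code has any errors.

Why not make the change directly in this PR?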

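Contributor Author

> So does that mean the static-graph code does not need to be written, and the new PIR converts dynamic to static automatically?
>
> It is not that the code goes unwritten. The auto-generated code in _C_ops.lu_solve performs the dynamic/static dispatch, so you simply do not need to write it by hand. Dynamic graph still follows the dynamic-graph logic and static graph still follows the static-graph logic; nothing has fundamentally changed.
>
> Also, I just removed the error in my other PR and reran it, so that it can check whether my static-graph code has any errors.
>
> Why not make the change directly in this PR?

[image]
Because changing it here would also require adding an export command.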

b: Tensor,
lu: Tensor,
pivots: Tensor,
trans: str = "N",
Member

If the values are enumerable, use Literal to annotate the type.

Contributor Author

> If the values are enumerable, use Literal to annotate the type.

Done

lu: Tensor,
pivots: Tensor,
trans: str = "N",
name=None,
Member

Suggested change
- name=None,
+ name: str | None = None,

Contributor Author

OK.

@decade-afk
Contributor Author

So what do I need to change at this point? Or do I still need to test the static-graph code? @SigureMo

Member

@SigureMo SigureMo left a comment

Apply this change; if CI passes, then everything is fine.

b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`, where :math:`*` is batch dimensions, with data type float32, float64.
lu (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch dimensions, that can be decomposed into an upper triangular matrix U and a lower triangular matrix L, with data type float32, float64.
pivots (Tensor): Permutation matrix P of LU decomposition. It has shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted to a permutation matrix P, with data type int32.
trans (str, optional): The transpose of the matrix A. It can be "N" , "T" or "C", "N" means :math:`Ax=b`, "T" means :math:`A^Tx=b`, "C" means :math:`A^Hx=b`, default is "N".
Contributor

"C" is presumably intended for complex numbers; for real inputs, "C" and "T" mean the same thing.

Contributor Author

> "C" is presumably intended for complex numbers; for real inputs, "C" and "T" mean the same thing.

Exactly. I noticed that lu also reserves an interface for complex numbers.
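
The point about "C" versus "T" can be checked with a few lines of plain Python. This is a sketch with hypothetical helper names (`transpose`, `conj_transpose`), independent of Paddle: the conjugate transpose A^H equals the plain transpose A^T whenever every entry is real, and differs once an entry has a nonzero imaginary part.

```python
def transpose(a):
    """Plain transpose A^T of a list-of-rows matrix."""
    return [list(col) for col in zip(*a)]


def conj_transpose(a):
    """Conjugate transpose A^H: transpose, then conjugate each entry."""
    return [[v.conjugate() for v in col] for col in zip(*a)]


real_a = [[1.0, 2.0], [3.0, 4.0]]
complex_a = [[1 + 2j, 0j], [3 + 0j, 4 - 1j]]

same_for_real = transpose(real_a) == conj_transpose(real_a)
same_for_complex = transpose(complex_a) == conj_transpose(complex_a)
```

Since this PR supports only float32 and float64, the two modes would coincide for every supported input.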

@decade-afk
Contributor Author

> Apply this change; if CI passes, then everything is fine.

Got it.

@luotao1 luotao1 merged commit 5911a3a into PaddlePaddle:develop Mar 10, 2025
31 checks passed
@luotao1 luotao1 changed the title 【Hackathon 8th No.1】add lu_solve api for paddle 【Hackathon 8th No.1】add lu_solve api for paddle -part Mar 10, 2025
6 participants