-
Notifications
You must be signed in to change notification settings - Fork 123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow passing numpy arrays to transpose #1258
base: main
Are you sure you want to change the base?
Conversation
tests/tensor/test_variable.py
Outdated
assert_array_equal(X.transpose(1, 0).eval({X: x}), x.transpose(1, 0)) | ||
|
||
# Test handing in lists and np.arrays | ||
assert_array_equal(X.transpose([1, 0]).eval({X: x}), x.transpose([1, 0])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to eval these, having confirmed the above works, you can use assert equal_computations([X.tranpose(array(...))], [X.transpose(list(...))])
that confirms you have the expected graph.
You can even check the two cases at once
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how checking them at the same time would look like, but fixed code to do equal_computation now.
Also added the tuple alongside list and np.array as that was the third option
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can have multiple variables in each side of equal_computations
, such as equal_computations([var1, var2], [var3, var4])
, which will check if var1 is equal to var3 and var2 to var4
Description
Added np.ndarray as a typecheck along tuple and list for dimshuffle so tensor.transpose would accept numpy inputs.
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1258.org.readthedocs.build/en/1258/