-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor array_api
namespace, relying more directly on jax.numpy
#21013
Conversation
array_api
namespace, relying more directly on jax.numpy
4642197
to
36e53d9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks really nice! A few comments below.
a146278
to
9e85f58
Compare
5c0b3de
to
f9dbcac
Compare
0897b7f
to
c4cfbcb
Compare
The current failure is due to needing to add some signature test skips for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work! This really makes clear what the remaining TODOs are 😀
This PR refactors the
jax.experimental.array_api
namespace by removing unnecessary wrappers around already-compliant functions in thejax.numpy
namespace, and structuring thearray_api
namespace to pull directly fromjax.numpy
whenever possible. After this PR, thearray_api
submodule contain only:jax.numpy
from breaking changes, which will be removed when the correspondingjax.numpy
behavior is deprecated and made array API compliantjax.numpy
and needs inclusion (e.g. introducingjax.numpy.matmul
, which already exists injax.numpy.linalg
).This PR also adds several
TODO
items describing what is required to cull that portion of thearray_api
submodule, with the understanding that once it is empty,jax.numpy
will be fully compliant. I figured it would be a bit neater to keep theTODO
notes dense in this submodule, rather than spreading them across thejax.numpy
submodule on their corresponding functions. It's also consistent with theTODO
s for new functionality or namespace elements.This PR also modifies
jax.numpy.isdtype
to accept_ScalarMeta
and other dtype-interpretable inputs.Note that the
array-api-tests
issues manyUserWarnings
for the special cases test, as well as for their reporting utilities due to not understanding what@jit
wrapped functions are in JAX, so this PR suppresses them in thejax-array-api
workflow.This PR has been validated against the
array-api-tests
suite for version2023.12
, usingjax/experimental/array_api/skips.txt
-- although it is worth noting that the test suite does not cover everything, e.g. is still missing support forcopy
anddevice
keyword tests.cc: @jakevdp