Flax nnx ConvTranspose Does Not Restore Input Shape When Used with Conv (Unexpected Behavior) #4593
Hey @Stella-S-Yan, using:

```python
import jax
from flax import nnx

batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 'VALID'
key = jax.random.key(0)

# ============= Flax ===========================
x = jax.random.uniform(key, shape=(batch_size, i, i, in_channels))
print(f'{x.shape = }')
conv = nnx.Conv(
    in_features=in_channels,
    out_features=out_channels,
    kernel_size=(k, k),
    strides=(s, s),
    padding=p,
    rngs=nnx.Rngs(0),
)
y = conv(x)
print(y.shape)  # (4, 2, 2, 32)
# assert y.shape[2] == 2
tconv = nnx.ConvTranspose(
    in_features=out_channels,
    out_features=in_channels,
    kernel_size=(k, k),
    strides=(s, s),
    padding=p,
    rngs=nnx.Rngs(0),
)
z = tconv(y)
print(z.shape)  # (4, 0, 0, 128)
if z.shape[2] != i:
    print('Flax ConvTranspose failed to restore the original input shape.')
```
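The shape arithmetic behind the example above can be written out directly (a plain-Python sketch of the standard 'VALID' size formulas, not Flax's internals): the forward convolution maps a spatial size `n` to `(n - k) // s + 1`, and its transpose should map `m` back to `(m - 1) * s + k`, which with stride 1 restores the original size of 4 rather than the reported 0.

```python
def conv_out(n, k, s):
    """Spatial size after a 'VALID' convolution."""
    return (n - k) // s + 1

def conv_transpose_out(m, k, s):
    """Expected spatial size after a 'VALID' transposed convolution."""
    return (m - 1) * s + k

i, k, s = 4, 3, 1
m = conv_out(i, k, s)                    # 4 -> 2, matching y.shape above
restored = conv_transpose_out(m, k, s)   # 2 -> 4, the size that should come back
print(m, restored)  # 2 4
```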
Setting the padding to "VALID" means zero padding. But explicitly setting the padding to 0 gives a wrong result, and this is something we should fix. Here is another example where no setting, whether "VALID", "SAME", or "CIRCULAR", works.
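For contrast, here is the ideal round-trip shape arithmetic each padding mode should satisfy, written as a plain-Python sketch (an assumption about the intended behavior with stride 1, not Flax's actual implementation):

```python
import math

def forward_size(n, k, s, padding):
    """Ideal spatial size after a forward convolution."""
    if padding == 'VALID':
        return (n - k) // s + 1
    # 'SAME' and 'CIRCULAR' pad so the size only shrinks by the stride.
    return math.ceil(n / s)

def transpose_size(m, k, s, padding):
    """Ideal spatial size after the matching transposed convolution."""
    if padding == 'VALID':
        return (m - 1) * s + k
    return m * s

# With n=4, k=3, s=1, every mode should give 4 back, never 0.
for padding in ('VALID', 'SAME', 'CIRCULAR'):
    n, k, s = 4, 3, 1
    m = forward_size(n, k, s, padding)
    print(padding, n, '->', m, '->', transpose_size(m, k, s, padding))
```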
When using Flax's `Conv` and `ConvTranspose` layers as a pair, the `ConvTranspose` does not seem to correctly restore the original input shape, even when the parameters are set in a way that should theoretically allow this. This behavior differs from PyTorch, where `ConvXd` and `ConvTransposeXd` used together reliably restore the input shape. The `ConvTranspose` function appears to produce incorrect output shapes, sometimes resulting in dimensions collapsing to zero. This behavior is not just a mismatch with PyTorch; it makes the function effectively unusable in certain cases.

Reproduction Example