Flax nnx ConvTranspose Does Not Restore Input Shape When Used with Conv (Unexpected Behavior) #4593

Stella-S-Yan opened this issue Mar 4, 2025 · 2 comments

Stella-S-Yan commented Mar 4, 2025

When using Flax's Conv and ConvTranspose layers as a pair, ConvTranspose does not correctly restore the original input shape, even when the parameters are set in a way that should, in theory, allow it. This differs from PyTorch, where ConvXd and ConvTransposeXd used together reliably restore the input shape.

The ConvTranspose layer appears to produce incorrect output shapes, sometimes with spatial dimensions collapsing to zero. This is not just a mismatch with PyTorch; it makes the layer effectively unusable in certain cases.
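
For reference, here is the shape arithmetic I would expect to hold (a minimal sketch; the helper functions below are mine, for illustration only, not from either library):

def conv_out(i, k, s, p):
    # standard forward-convolution output size
    return (i + 2 * p - k) // s + 1

def conv_transpose_out(i, k, s, p):
    # PyTorch ConvTranspose2d output size (output_padding = 0)
    return (i - 1) * s - 2 * p + k

i, k, s, p = 4, 3, 1, 0
o = conv_out(i, k, s, p)               # 2
print(conv_transpose_out(o, k, s, p))  # 4 == i, so the shape should be restored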

Reproduction Example

from jax import random
from flax import nnx

import torch
from torch import nn

key = random.PRNGKey(42) 

batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 0

# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k), 
                strides=(s, s), 
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape) # (4, 2, 2, 32)
assert y.shape[2] == 2

tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding=p,
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape) # (4, 0, 0, 128)
if z.shape[2] != i:
    print("Flax ConvTranspose failed to restore original input shape.")

# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape) # torch.Size([4, 32, 2, 2])
assert y.shape == (batch_size, out_channels, 2, 2)

# PyTorch ConvTranspose2d output size: (2 - 1) * 1 - 2 * 0 + 3 = 4
tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s, 
                           padding=p)
z = tconv(y)
print(z.shape) # torch.Size([4, 128, 4, 4])
assert z.shape[2] == i

cgarciae (Collaborator) commented Mar 4, 2025

Hey @Stella-S-Yan, using p = 'VALID' will give you the shapes you specify there:

import jax
from flax import nnx

batch_size = 4
in_channels = 128
out_channels = 32
i = 4
k = 3
s = 1
p = 'VALID'
key = jax.random.key(0)

# ============= Flax ===========================
x = jax.random.uniform(key, shape=(batch_size, i, i, in_channels))
print(f'{x.shape = }')
conv = nnx.Conv(
  in_features=in_channels,
  out_features=out_channels,
  kernel_size=(k, k),
  strides=(s, s),
  padding=p,
  rngs=nnx.Rngs(0),
)
y = conv(x)
print(y.shape)  # (4, 2, 2, 32)
# assert y.shape[2] == 2

tconv = nnx.ConvTranspose(
  in_features=out_channels,
  out_features=in_channels,
  kernel_size=(k, k),
  strides=(s, s),
  padding=p,
  rngs=nnx.Rngs(0),
)
z = tconv(y)
print(z.shape)  # (4, 4, 4, 128): original spatial shape restored
if z.shape[2] != i:
  print('Flax ConvTranspose failed to restore original input shape.')
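
If I have Flax's rule right, with padding='VALID' the transposed convolution produces (i - 1) * s + k = (2 - 1) * 1 + 3 = 4 along each spatial dimension, which coincides with PyTorch's (i - 1) * s - 2 * p + k exactly when p = 0, so the shapes round-trip in this particular case.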

Stella-S-Yan (Author) commented Mar 4, 2025

Setting padding to "VALID" means zero padding, but explicitly setting padding to 0 gives the wrong result, and that is something we should fix. Here is another example where neither "VALID" nor "SAME" nor "CIRCULAR" restores the input shape.

from jax import random
from flax import nnx

import torch
from torch import nn

key = random.PRNGKey(42) 

batch_size = 4
in_channels = 128
out_channels = 32
i = 5
k = 4
s = 1
p = 2

ip = 6  # expected spatial size after Conv: (5 + 2*2 - 4) // 1 + 1 = 6

# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k), 
                strides=(s, s), 
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape)  # (4, 6, 6, 32)
assert y.shape[2] == ip

tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding="VALID",
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # (4, 9, 9, 128): shape is not restored
if z.shape[2] != i:
    print("Flax ConvTranspose failed to restore original input shape.")


# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape) # torch.Size([4, 32, 6, 6])
assert y.shape[2] == ip 

# PyTorch ConvTranspose2d output size: (6 - 1) * 1 - 2 * 2 + 4 = 5
tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s, 
                           padding=p)
z = tconv(y)
print(z.shape) # torch.Size([4, 128, 5, 5])
assert z.shape[2] == i
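
One possible workaround, continuing from the Flax section above (where y has shape (4, 6, 6, 32)): if Flax applies an explicit (low, high) padding pair to the stride-dilated input, so that output = (i - 1)*s + 2 + lo + hi - k, then padding by k - 1 - p on each side should reproduce PyTorch's padding=p convention. This is a sketch based on that assumption, not verified beyond the shapes observed in this thread:

# Assumption: explicit padding pairs are applied to the stride-dilated
# input; choosing lo = hi = k - 1 - p then yields (i - 1)*s - 2*p + k,
# i.e. PyTorch's ConvTranspose2d output size.
pt = k - 1 - p  # k=4, p=2 -> 1 per side
tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding=[(pt, pt), (pt, pt)],
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # (4, 5, 5, 128) if the assumption holds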
