Vit fix #8

Merged: 2 commits, Jul 29, 2022
eqxvision/__init__.py (1 addition, 1 deletion)

@@ -1,4 +1,4 @@
 r"""Root package info."""
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 
 from . import layers, models
eqxvision/models/classification/googlenet.py (1 addition, 1 deletion)

@@ -81,7 +81,7 @@ def forward(net, x, key):
         inception_block = blocks[1]
         inception_aux_block = blocks[2]
 
-        if not key:
+        if key is None:
             key = jrandom.PRNGKey(0)
         keys = jrandom.split(key, 20)
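Side note on the recurring `if not key:` to `if key is None:` change in this PR: a JAX PRNG key is a small uint32 array, and asking for its truth value raises an "ambiguous truth value" error (or fails under tracing), so any explicitly passed key would break the old check. A minimal standalone sketch of the fixed pattern (the helper name and constants are illustrative only, not eqxvision code):

```python
import jax.random as jrandom

def make_keys(key=None, n=20):
    # A PRNGKey is an array like [0, 0]; `if not key:` would try to reduce it
    # to a single bool and raise, so compare against None instead.
    if key is None:
        key = jrandom.PRNGKey(0)
    return jrandom.split(key, n)  # one sub-key per submodule, as in the diff

keys = make_keys(jrandom.PRNGKey(42))
```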
eqxvision/models/classification/resnet.py (10 additions, 8 deletions)

@@ -1,4 +1,4 @@
-from typing import Callable, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, List, Optional, Sequence, Type, Union
 
 import equinox as eqx
 import equinox.experimental as eqex
@@ -179,7 +179,6 @@ def __call__(self, x: Array, *, key: Optional["jax.random.PRNGKey"] = None):
 class ResNet(eqx.Module):
     """A simple port of torchvision.models.resnet"""
 
-    _norm_layer: Callable
     inplanes: int
     dilation: int
     groups: Sequence[int]
@@ -204,7 +203,7 @@ def __init__(
         groups: int = 1,
         width_per_group: int = 64,
         replace_stride_with_dilation: List[bool] = None,
-        norm_layer: eqx.Module = None,
+        norm_layer: Any = None,
         *,
         key: Optional["jax.random.PRNGKey"] = None,
     ):
@@ -247,11 +246,10 @@ def forward(net, x, key):
             raise NotImplementedError(
                 f"{type(norm_layer)} is not currently supported. Use `eqx.experimental.BatchNorm` instead."
             )
-        if not key:
+        if key is None:
             key = jrandom.PRNGKey(0)
 
         keys = jrandom.split(key, 6)
-        self._norm_layer = norm_layer
         self.inplanes = 64
         self.dilation = 1
         if replace_stride_with_dilation is None:
@@ -277,11 +275,12 @@ def forward(net, x, key):
         self.bn1 = norm_layer(input_size=self.inplanes, axis_name="batch")
         self.relu = jnn.relu
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], key=keys[1])
+        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer, key=keys[1])
         self.layer2 = self._make_layer(
             block,
             128,
             layers[1],
+            norm_layer,
             stride=2,
             dilate=replace_stride_with_dilation[0],
             key=keys[2],
@@ -290,6 +289,7 @@ def forward(net, x, key):
             block,
             256,
             layers[2],
+            norm_layer,
             stride=2,
             dilate=replace_stride_with_dilation[1],
             key=keys[3],
@@ -298,6 +298,7 @@ def forward(net, x, key):
             block,
             512,
             layers[3],
+            norm_layer,
             stride=2,
             dilate=replace_stride_with_dilation[2],
             key=keys[4],
@@ -306,9 +307,10 @@ def forward(net, x, key):
         self.fc = nn.Linear(512 * EXPANSIONS[block], num_classes, key=keys[5])
         # TODO: Zero initialize BNs as per torchvision
 
-    def _make_layer(self, block, planes, blocks, stride=1, dilate=False, key=None):
+    def _make_layer(
+        self, block, planes, blocks, norm_layer, stride=1, dilate=False, key=None
+    ):
         keys = jrandom.split(key, blocks + 1)
-        norm_layer = self._norm_layer
         downsample = None
         previous_dilation = self.dilation
         if dilate:
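The ResNet changes all serve one refactor: the norm-layer factory is no longer stashed on the module as a `_norm_layer` field (an `eqx.Module` is a frozen pytree, so a bare class stored as a field is awkward to type and to filter) and is instead threaded through `_make_layer` as an explicit argument while `__init__` runs. A minimal sketch of that pattern with made-up layer sizes, not the actual eqxvision classes:

```python
import equinox as eqx
import jax

class TinyNet(eqx.Module):
    layer1: eqx.nn.Linear
    layer2: eqx.nn.Linear

    def __init__(self, layer_factory, *, key):
        k1, k2 = jax.random.split(key)
        # The factory is used during construction and then dropped; it is
        # never assigned to self, so it never becomes part of the pytree.
        self.layer1 = self._make_layer(layer_factory, 4, 8, key=k1)
        self.layer2 = self._make_layer(layer_factory, 8, 2, key=k2)

    @staticmethod
    def _make_layer(layer_factory, in_size, out_size, *, key):
        return layer_factory(in_size, out_size, key=key)

net = TinyNet(eqx.nn.Linear, key=jax.random.PRNGKey(0))
```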
eqxvision/models/classification/vit.py (5 additions, 3 deletions)

@@ -220,9 +220,9 @@ def forward(net, x, keys):
         """
 
         super().__init__()
-        if not key:
+        if key is None:
            key = jrandom.PRNGKey(0)
-        keys = jrandom.split(key, depth + 2)
+        keys = jrandom.split(key, depth + 3)
         self.inference = False
         self.num_features = embed_dim
         self.patch_embed = PatchEmbed(
@@ -259,7 +259,9 @@ def forward(net, x, keys):
         self.norm = norm_layer(embed_dim)
         # Classifier head
         self.fc = (
-            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+            nn.Linear(embed_dim, num_classes, key=keys[-1])
+            if num_classes > 0
+            else nn.Identity()
         )
         # ToDo: Initialization scheme of the weights
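This is the core ViT fix: `key` is a required keyword-only argument of `eqx.nn.Linear`, so the classifier head now receives the last sub-key, and the split grows from `depth + 2` to `depth + 3` to reserve it. A standalone sketch of the fixed head construction (the depth and width constants are illustrative, not eqxvision defaults):

```python
import equinox.nn as nn
import jax.random as jrandom

depth, embed_dim, num_classes = 12, 192, 1000
key = jrandom.PRNGKey(0)
keys = jrandom.split(key, depth + 3)  # extra sub-key reserved for the head

# eqx.nn.Linear cannot initialise its weights without key=..., which is
# what the pre-fix nn.Linear(embed_dim, num_classes) call ran into.
fc = (
    nn.Linear(embed_dim, num_classes, key=keys[-1])
    if num_classes > 0
    else nn.Identity()
)
```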
tests/test_models/test_vit.py (4 additions, 4 deletions)

@@ -110,23 +110,23 @@ def forward(net, x, keys):
         return jax.vmap(net)(x, key=keys)
 
     random_input = jax.random.uniform(key=getkey(), shape=(1, 3, 224, 224))
-    answer = (1, 192)
-    net = models.vit_tiny()
+    answer = (1, 1000)
+    net = models.vit_tiny(num_classes=1000)
     keys = jax.random.split(getkey(), random_input.shape[0])
 
     output = forward(net, random_input, keys)
     assert output.shape == answer
     assert c_counter == 1
 
     answer = (1, 384)
-    net = models.vit_small()
+    net = models.vit_small(num_classes=0)
 
     output = forward(net, random_input, keys)
     assert output.shape == answer
     assert c_counter == 2
 
     answer = (1, 768)
-    net = models.vit_base()
+    net = models.vit_base(num_classes=0)
 
     output = forward(net, random_input, keys)
     assert output.shape == answer
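The updated assertions follow directly from the head change above: with `num_classes=0` the head is `nn.Identity()`, so the model emits raw features of width `embed_dim` (384 for vit_small, 768 for vit_base), whereas `num_classes=1000` produces 1000 logits. A rough recreation of the first two checks, assuming the eqxvision API used in this diff:

```python
import jax
from eqxvision import models

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key=key, shape=(1, 3, 224, 224))
keys = jax.random.split(key, x.shape[0])

net = models.vit_tiny(num_classes=1000)   # Linear head -> 1000 logits
out = jax.vmap(net)(x, key=keys)
assert out.shape == (1, 1000)

net = models.vit_small(num_classes=0)     # Identity head -> 384-dim features
out = jax.vmap(net)(x, key=keys)
assert out.shape == (1, 384)
```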