Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #65

Merged
merged 2 commits into from
Oct 25, 2022
Merged

Dev #65

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Picking a model and doing a forward pass is as simple as ...

## What's New?

- `FCN` and `DeepLabV3` added as new image segmentation models.
- `FCN`, `DeepLabV3` and `LRASPP` added as new image segmentation models.
- Backward incompatible changes to `v0.2.0` for loading a `pretrained` model.
- Almost all image classification models are ported from `torchvision`.
- New tutorial for generating `adversarial examples` and others coming soon.
Expand Down
12 changes: 12 additions & 0 deletions docs/api/models/segmentation/lraspp.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# LRASPP


::: eqxvision.models.LRASPP
selection:
members:
- __init__
- __call__

---

::: eqxvision.models.lraspp_mobilenet_v3_large
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ pip install eqxvision
```

## What's New?
- `FCN`, `DeepLabV3` and `LRASPP` segmentation models are now supported (check out the [tutorial](getting_started/FCN_Segmentation.ipynb)).
- Backward incompatible changes to `v0.2.0` for loading a `pretrained` model.
- `FCN` and `DeepLabV3` segmentation models are now supported (check out the [tutorial](getting_started/FCN_Segmentation.ipynb)).
- Almost all image classification models are ported from `torchvision`.
- New tutorial for generating [adversarial examples](getting_started/Adversarial_Attack.ipynb) and others coming soon.

Expand Down
2 changes: 1 addition & 1 deletion eqxvision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r"""Root package info."""
__version__ = "0.2.5"
__version__ = "0.2.6"

from . import experimental, layers, models, utils
38 changes: 23 additions & 15 deletions eqxvision/experimental.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Callable

import equinox as eqx
from jaxtyping import PyTree
import equinox.nn as nn


class AuxData:
Expand Down Expand Up @@ -33,7 +33,7 @@ def __call__(self, x, *, key=None):


def intermediate_layer_getter(
model: PyTree, get_target_layers: Callable
model: "eqx.Module", get_target_layers: Callable
) -> "eqx.Module":
"""Wraps intermediate layers of a model for accessing intermediate activations. Based on a discussion
[here](https://github.com/patrick-kidger/equinox/issues/186).
Expand All @@ -49,26 +49,34 @@ def intermediate_layer_getter(
of layers from the `model`

**Returns:**
A `PyTree`, encapsulating `model` for storing intermediate outputs from target layers.
The returned model will now return a `tuple` with

!!! info
The returned model will now return a `tuple` with
0. The final output of `model`
1. An ordered list of intermediate activations

1. The final output of `model`
2. An ordered list of intermediate activations
"""
target_layers = get_target_layers(model)
auxs, wrappers = zip(
*[_make_intermediate_layer_wrapper() for _ in range(len(target_layers))]
)
model = eqx.tree_at(
where=get_target_layers,
pytree=model,
replace=[
wrapper(target_layer)
for (wrapper, target_layer) in zip(wrappers, target_layers)
],
)
if isinstance(model, nn.Sequential):
new_modules, updated_count = [], 0
for idx, module in enumerate(model.layers):
if idx in target_layers:
new_modules.append(wrappers[updated_count](module))
updated_count += 1
else:
new_modules.append(module)
model = nn.Sequential(new_modules)
else:
model = eqx.tree_at(
where=get_target_layers,
pytree=model,
replace=[
wrapper(target_layer)
for (wrapper, target_layer) in zip(wrappers, target_layers)
],
)

class IntermediateLayerGetter(eqx.Module):
model: eqx.Module
Expand Down
1 change: 1 addition & 0 deletions eqxvision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,4 @@
)
from .segmentation.deeplabv3 import DeepLabV3, deeplabv3
from .segmentation.fcn import FCN, fcn
from .segmentation.lraspp import LRASPP, lraspp_mobilenet_v3_large
5 changes: 4 additions & 1 deletion eqxvision/models/classification/mobilenetv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,10 @@ def mobilenet_v3_large(torch_weights: str = None, **kwargs: Any) -> MobileNetV3:

"""
arch = "mobilenet_v3_large"
inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
dilated = kwargs.pop("dilated", False)
inverted_residual_setting, last_channel = _mobilenet_v3_conf(
arch, dilated=dilated, **kwargs
)
model = _mobilenet_v3(arch, inverted_residual_setting, last_channel, **kwargs)
if torch_weights:
model = load_torch_weights(model, torch_weights=torch_weights)
Expand Down
2 changes: 1 addition & 1 deletion eqxvision/models/segmentation/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import deeplabv3, fcn
from . import deeplabv3, fcn, lraspp
18 changes: 10 additions & 8 deletions eqxvision/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jaxtyping import Array

from ...experimental import intermediate_layer_getter
from ...utils import CLASSIFICATION_URLS, load_torch_weights
from ...utils import load_torch_weights
from ..classification import resnet
from ._utils import _SimpleSegmentationModel
from .fcn import FCNHead
Expand Down Expand Up @@ -139,17 +139,17 @@ def deeplabv3(
num_classes: Optional[int] = 21,
backbone: "eqx.Module" = None,
intermediate_layers: Callable = None,
classifier_module: "eqx.Module" = DeepLabHead,
classifier_module: "eqx.Module" = None,
classifier_in_channels: int = 2048,
aux_classifier_module: "eqx.Module" = FCNHead,
aux_classifier_module: "eqx.Module" = None,
aux_in_channels: int = 1024,
silence_layers: Callable = None,
torch_weights: str = None,
*,
key: Optional["jax.random.PRNGKey"] = None,
) -> DeepLabV3:
"""Implements DeepLabV3 model from
["Rethinking Atrous Convolution for Semantic Image Segmentation"](https://arxiv.org/abs/1706.05587) paper.
[Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) paper.

!!! info "Sample call"
```python
Expand All @@ -167,7 +167,7 @@ def deeplabv3(
- `num_classes`: Number of classes in the segmentation task.
Also controls the final output shape `(num_classes, height, width)`. Defaults to `21`
- `backbone`: The neural network to use for extracting features. If `None`, then all params are set to
`DeepLabV3_RESNET50` with a **pre-trained** backbone but **untrained** DeepLabV3 heads
`DeepLabV3_RESNET50` with `untrained` weights
- `intermediate_layers`: Layers from `backbone` to be used for generating output maps. Default sets it to
`layer3` and `layer4` from `DeepLabV3_RESNET50`
- `classifier_module`: Uses the `DeepLabHead` by default
Expand All @@ -179,15 +179,17 @@ def deeplabv3(
the `fc` layers can be dropped. This is particularly useful when loading weights from `torchvision`. By
default, `.fc` layer of a model is set to identity to avoid tracking weights.
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`

- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation
"""
if key is None:
key = jr.PRNGKey(0)
keys = jr.split(key, 2)

if not classifier_module:
classifier_module = DeepLabHead
if not aux_classifier_module:
aux_classifier_module = FCNHead
if backbone is None:
backbone = resnet.resnet50(
torch_weights=CLASSIFICATION_URLS["resnet50"],
replace_stride_with_dilation=[False, True, True],
)
num_layers = len(intermediate_layers(backbone))
Expand Down
12 changes: 6 additions & 6 deletions eqxvision/models/segmentation/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jax.random as jr

from ...experimental import intermediate_layer_getter
from ...utils import CLASSIFICATION_URLS, load_torch_weights
from ...utils import load_torch_weights
from ..classification import resnet
from ._utils import _SimpleSegmentationModel

Expand Down Expand Up @@ -38,7 +38,7 @@ def fcn(
num_classes: Optional[int] = 21,
backbone: "eqx.Module" = None,
intermediate_layers: Callable = None,
classifier_module: "eqx.Module" = FCNHead,
classifier_module: "eqx.Module" = None,
classifier_in_channels: int = 2048,
aux_in_channels: int = None,
silence_layers: Callable = None,
Expand All @@ -64,7 +64,7 @@ def fcn(
- `num_classes`: Number of classes in the segmentation task.
Also controls the final output shape `(num_classes, height, width)`. Defaults to `21`
- `backbone`: The neural network to use for extracting features. If `None`, then all params are set to
`FCN_RESNET50` with a **pre-trained** backbone but an **untrained** FCN
`FCN_RESNET50` with `untrained` weights
- `intermediate_layers`: Layers from `backbone` to be used for generating output maps. Default sets it to
`layer3` and `layer4` from `FCN_RESNET50`
- `classifier_module`: Uses the `FCNHead` by default
Expand All @@ -75,15 +75,15 @@ def fcn(
the `fc` layers can be dropped. This is particularly useful when loading weights from `torchvision`. By
default, `.fc` layer of a model is set to identity to avoid tracking weights.
- `torch_weights`: A `Path` or `URL` for the `PyTorch` weights. Defaults to `None`

- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter initialisation
"""
if key is None:
key = jr.PRNGKey(0)
keys = jr.split(key, 2)

if classifier_module is None:
classifier_module = FCNHead
if backbone is None:
backbone = resnet.resnet50(
torch_weights=CLASSIFICATION_URLS["resnet50"],
replace_stride_with_dilation=[False, True, True],
)
num_layers = len(intermediate_layers(backbone))
Expand Down
Loading