Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support albu transform #2943

Merged
merged 3 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .circleci/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jobs:
pip install mmcls==1.0.0rc6
pip install git+https://github.com/open-mmlab/mmdetection.git@main
pip install -r requirements/tests.txt -r requirements/optional.txt
python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
- run:
name: Build and install
command: |
Expand Down Expand Up @@ -111,6 +112,7 @@ jobs:
docker exec mmseg pip install mmcls==1.0.0rc6
docker exec mmseg pip install -e /mmdetection
docker exec mmseg pip install -r requirements/tests.txt -r requirements/optional.txt
docker exec mmseg python -m pip install albumentations>=0.3.2 --no-binary qudida,albumentations
- run:
name: Build and install
command: |
Expand Down
4 changes: 2 additions & 2 deletions mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .stare import STAREDataset
from .synapse import SynapseDataset
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
Expand Down Expand Up @@ -51,5 +51,5 @@
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
'MapillaryDataset_v2'
'MapillaryDataset_v2', 'Albu'
]
4 changes: 2 additions & 2 deletions mmseg/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
LoadBiomedicalData, LoadBiomedicalImageFromFile,
LoadImageFromNDArray)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, BioMedical3DPad,
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
BioMedical3DRandomCrop, BioMedical3DRandomFlip,
BioMedicalGaussianBlur, BioMedicalGaussianNoise,
BioMedicalRandomGamma, GenerateEdge,
Expand All @@ -22,5 +22,5 @@
'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
'RandomRotFlip'
'RandomRotFlip', 'Albu'
]
156 changes: 156 additions & 0 deletions mmseg/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

import cv2
import mmcv
import mmengine
import numpy as np
from mmcv.transforms.base import BaseTransform
from mmcv.transforms.utils import cache_randomness
Expand All @@ -15,6 +17,15 @@
from mmseg.datasets.dataset_wrappers import MultiImageMixDataset
from mmseg.registry import TRANSFORMS

try:
import albumentations
from albumentations import Compose
ALBU_INSTALLED = True
except ImportError:
albumentations = None
Compose = None
ALBU_INSTALLED = False


@TRANSFORMS.register_module()
class ResizeToMultiple(BaseTransform):
Expand Down Expand Up @@ -2135,3 +2146,148 @@ def __repr__(self):
repr_str += f'(prob={self.prob}, axes={self.axes}, ' \
f'swap_label_pairs={self.swap_label_pairs})'
return repr_str


@TRANSFORMS.register_module()
class Albu(BaseTransform):
"""Albumentation augmentation. Adds custom transformations from
Albumentations library. Please, visit
`https://albumentations.readthedocs.io` to get more information. An example
of ``transforms`` is as followed:

.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (list[dict]): A list of albu transformations
keymap (dict): Contains {'input key':'albumentation-style key'}
update_pad_shape (bool): Whether to update padding shape according to \
the output shape of the last transform
"""

def __init__(self,
transforms: List[dict],
keymap: Optional[dict] = None,
update_pad_shape: bool = False):
if not ALBU_INSTALLED:
raise ImportError(
'albumentations is not installed, '
'we suggest install albumentation by '
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
)

# Args will be modified later, copying it will be safer
transforms = copy.deepcopy(transforms)

self.transforms = transforms
self.keymap = keymap
self.update_pad_shape = update_pad_shape

self.aug = Compose([self.albu_builder(t) for t in self.transforms])

if not keymap:
self.keymap_to_albu = {
'img': 'image',
'gt_masks': 'masks',
}
else:
self.keymap_to_albu = copy.deepcopy(keymap)
self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}

def albu_builder(self, cfg: dict) -> object:
"""Build a callable object from a dict containing albu arguments.

Args:
cfg (dict): Config dict. It should at least contain the key "type".

Returns:
Callable: A callable object.
"""

assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy()

obj_type = args.pop('type')
if mmengine.is_str(obj_type):
if not ALBU_INSTALLED:
raise ImportError(
'albumentations is not installed, '
'we suggest install albumentation by '
'"pip install albumentations>=0.3.2 --no-binary qudida,albumentations"' # noqa
)
obj_cls = getattr(albumentations, obj_type)
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(
f'type must be a valid type or str, but got {type(obj_type)}')

if 'transforms' in args:
args['transforms'] = [
self.albu_builder(t) for t in args['transforms']
]

return obj_cls(**args)

@staticmethod
def mapper(d: dict, keymap: dict):
"""Dictionary mapper.

Renames keys according to keymap provided.
Args:
d (dict): old dict
keymap (dict): {'old_key':'new_key'}
Returns:
dict: new dict.
"""

updated_dict = {}
for k, _ in zip(d.keys(), d.values()):
new_k = keymap.get(k, k)
updated_dict[new_k] = d[k]
return updated_dict

def transform(self, results):
# dict to albumentations format
results = self.mapper(results, self.keymap_to_albu)

# Convert to RGB since Albumentations works with RGB images
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_BGR2RGB)

results = self.aug(**results)

# Convert back to BGR
results['image'] = cv2.cvtColor(results['image'], cv2.COLOR_RGB2BGR)

# back to the original format
results = self.mapper(results, self.keymap_back)

# update final shape
if self.update_pad_shape:
results['pad_shape'] = results['img'].shape

return results

def __repr__(self):
repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
return repr_str
58 changes: 58 additions & 0 deletions tests/test_datasets/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,3 +1160,61 @@ def test_biomedical_3d_flip():
results = transform(results)
assert np.equal(original_img, results['img']).all()
assert np.equal(original_seg, results['gt_seg_map']).all()


def test_albu_transform():
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))

# Define simple pipeline
load = dict(type='LoadImageFromFile')
load = TRANSFORMS.build(load)

albu_transform = dict(
type='Albu', transforms=[dict(type='ChannelShuffle', p=1)])
albu_transform = TRANSFORMS.build(albu_transform)

normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True)
normalize = TRANSFORMS.build(normalize)

# Execute transforms
results = load(results)
results = albu_transform(results)
results = normalize(results)

assert results['img'].dtype == np.float32


def test_albu_channel_order():
results = dict(
img_path=osp.join(osp.dirname(__file__), '../data/color.jpg'))

# Define simple pipeline
load = dict(type='LoadImageFromFile')
load = TRANSFORMS.build(load)

# Transform is modifying B channel
albu_transform = dict(
type='Albu',
transforms=[
dict(
type='RGBShift',
r_shift_limit=0,
g_shift_limit=0,
b_shift_limit=200,
p=1)
])
albu_transform = TRANSFORMS.build(albu_transform)

# Execute transforms
results_load = load(results)
results_albu = albu_transform(results_load)

# assert only Green and Red channel are not modified
np.testing.assert_array_equal(results_albu['img'][..., 1:],
results_load['img'][..., 1:])

# assert Blue channel is modified
with pytest.raises(AssertionError):
np.testing.assert_array_equal(results_albu['img'][..., 0],
results_load['img'][..., 0])