Fix dice score when zero overlap between preds and target (#2860)
* fix implementation

* add tests

* changelog

* introduce zero_division argument

* add tests

* changelog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* less restrictive check

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 73cc585)
SkafteNicki authored and Borda committed Feb 28, 2025
1 parent 2edaa31 commit 87b39df
Showing 5 changed files with 84 additions and 9 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Added

+- Added `zero_division` argument to `DiceScore` in segmentation package ([#2860](https://github.com/PyTorchLightning/metrics/pull/2860))
+
+
 - Added `cache_session` to `DNSMOS` metric to control caching behavior ([#2974](https://github.com/PyTorchLightning/metrics/pull/2974))


@@ -30,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Fixed

+- Fixed `DiceScore` when there is zero overlap between predictions and targets ([#2860](https://github.com/PyTorchLightning/metrics/pull/2860))
+
+
 - Fix `MeanAveragePrecision` for `average="micro"` when 0 label is not present ([#2968](https://github.com/PyTorchLightning/metrics/pull/2968))
16 changes: 13 additions & 3 deletions src/torchmetrics/functional/segmentation/dice.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import Tensor
@@ -27,6 +27,7 @@ def _dice_score_validate_args(
     include_background: bool,
     average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
     input_format: Literal["one-hot", "index"] = "one-hot",
+    zero_division: Union[float, Literal["warn", "nan"]] = 1.0,
 ) -> None:
     """Validate the arguments of the metric."""
     if not isinstance(num_classes, int) or num_classes <= 0:
@@ -38,6 +39,10 @@
         raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.")
     if input_format not in ["one-hot", "index"]:
         raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.")
+    if zero_division not in [1.0, 0.0, "warn", "nan"]:
+        raise ValueError(
+            f"Expected argument `zero_division` to be one of 1.0, 0.0, 'warn', 'nan', but got {zero_division}."
+        )


 def _dice_score_update(
@@ -76,16 +81,21 @@ def _dice_score_compute(
     denominator: Tensor,
     average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
     support: Optional[Tensor] = None,
+    zero_division: Union[float, Literal["warn", "nan"]] = 1.0,
 ) -> Tensor:
     """Compute the Dice score from the numerator and denominator."""
+    # If both numerator and denominator are 0, the dice score is 0
+    if torch.all(numerator == 0) and torch.all(denominator == 0):
+        return torch.tensor(0.0, device=numerator.device, dtype=torch.float)
+
     if average == "micro":
         numerator = torch.sum(numerator, dim=-1)
         denominator = torch.sum(denominator, dim=-1)
-    dice = _safe_divide(numerator, denominator, zero_division=1.0)
+    dice = _safe_divide(numerator, denominator, zero_division=zero_division)
     if average == "macro":
         dice = torch.mean(dice, dim=-1)
     elif average == "weighted" and support is not None:
-        weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=1.0)
+        weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=zero_division)
         dice = torch.sum(dice * weights, dim=-1)
     return dice
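For illustration only (not part of this commit): a minimal sketch of how `zero_division` flows through the compute path above, assuming the private helper keeps the signature shown in this diff. The expected values follow from the code as written.

import torch

from torchmetrics.functional.segmentation.dice import _dice_score_compute

# One sample, three classes: class 0 has target support but no predicted
# overlap; classes 1 and 2 are empty in both preds and target (0/0 cases).
numerator = torch.tensor([[0.0, 0.0, 0.0]])  # 2 * |preds & target| per class
denominator = torch.tensor([[4.0, 0.0, 0.0]])  # |preds| + |target| per class

_dice_score_compute(numerator, denominator, average="none", zero_division=1.0)
# tensor([[0., 1., 1.]]) -- empty classes take the zero_division value
_dice_score_compute(numerator, denominator, average="macro", zero_division=1.0)
# tensor([0.6667]) -- mean over the per-class scores above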
9 changes: 7 additions & 2 deletions src/torchmetrics/segmentation/dice.py
@@ -68,6 +68,8 @@ class DiceScore(Metric):
             or ``None``. This determines how to average the dice score across different classes.
         input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors
             or ``"index"`` for index tensors
+        zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan".
+            Setting it to "warn" behaves like 0.0 but will also create a warning.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

     Raises:
@@ -110,14 +112,16 @@ def __init__(
         include_background: bool = True,
         average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
         input_format: Literal["one-hot", "index"] = "one-hot",
+        zero_division: Union[float, Literal["warn", "nan"]] = 0.0,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
-        _dice_score_validate_args(num_classes, include_background, average, input_format)
+        _dice_score_validate_args(num_classes, include_background, average, input_format, zero_division)
         self.num_classes = num_classes
         self.include_background = include_background
         self.average = average
         self.input_format = input_format
+        self.zero_division = zero_division

         num_classes = num_classes - 1 if not include_background else num_classes
         self.add_state("numerator", [], dist_reduce_fx="cat")
@@ -140,7 +144,8 @@ def compute(self) -> Tensor:
             dim_zero_cat(self.denominator),
             self.average,
             support=dim_zero_cat(self.support) if self.average == "weighted" else None,
-        ).mean(dim=0)
+            zero_division=self.zero_division,
+        ).nanmean(dim=0)

     def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
         """Plot a single or multiple values from the metric.
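A short usage sketch (illustrative, mirroring the new tests further down) of the module-level metric once `zero_division` is exposed:

import torch

from torchmetrics.segmentation import DiceScore

# One-hot input with three classes; only class 0 appears in the target and
# the prediction is all zeros, so every class is a zero-overlap case.
target = torch.zeros(1, 3, 128, 128, dtype=torch.int8)
preds = torch.zeros(1, 3, 128, 128, dtype=torch.int8)
target[0, 0] = 1

DiceScore(num_classes=3, average="macro", zero_division=1.0)(preds, target)
# tensor(0.6667) -- class 0 scores 0, the two empty classes count as 1.0
DiceScore(num_classes=3, average="macro", zero_division="nan")(preds, target)
# tensor(nan) -- nan propagates through the macro mean over classes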
19 changes: 15 additions & 4 deletions src/torchmetrics/utilities/compute.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Optional, Union

 import torch
 from torch import Tensor
 from typing_extensions import Literal

+from torchmetrics.utilities import rank_zero_warn
+

 def _safe_matmul(x: Tensor, y: Tensor) -> Tensor:
     """Safe calculation of matrix multiplication.
@@ -44,7 +46,11 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor:
     return res


-def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
+def _safe_divide(
+    num: Tensor,
+    denom: Tensor,
+    zero_division: Union[float, Literal["warn", "nan"]] = 0.0,
+) -> Tensor:
     """Safe division, by preventing division by zero.

     Function will cast to float if input is not already to secure backwards compatibility.
@@ -64,8 +70,13 @@ def _safe_divide(
     """
     num = num if num.is_floating_point() else num.float()
     denom = denom if denom.is_floating_point() else denom.float()
-    zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True)
-    return torch.where(denom != 0, num / denom, zero_division_tensor)
+    if isinstance(zero_division, (float, int)) or zero_division == "warn":
+        if zero_division == "warn" and torch.any(denom == 0):
+            rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0")
+        zero_division = 0.0 if zero_division == "warn" else zero_division
+        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True)
+        return torch.where(denom != 0, num / denom, zero_division_tensor)
+    return torch.true_divide(num, denom)


 def _adjust_weights_safe_divide(
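And a quick sketch of the extended `_safe_divide` semantics, following directly from the implementation above (the comments show the values implied by the code, not verified output):

import torch

from torchmetrics.utilities.compute import _safe_divide

num = torch.tensor([1.0, 0.0])
denom = torch.tensor([2.0, 0.0])

_safe_divide(num, denom)  # tensor([0.5000, 0.0000]) -- default fills 0/0 with 0.0
_safe_divide(num, denom, zero_division=1.0)  # tensor([0.5000, 1.0000])
_safe_divide(num, denom, zero_division="warn")  # tensor([0.5000, 0.0000]), plus a rank-zero warning
_safe_divide(num, denom, zero_division="nan")  # tensor([0.5000, nan]) via torch.true_divide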
43 changes: 43 additions & 0 deletions tests/unittests/segmentation/test_dice.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 from sklearn.metrics import f1_score
+from torch import tensor

 from torchmetrics import MetricCollection
 from torchmetrics.functional.segmentation.dice import dice_score
@@ -109,6 +110,48 @@ def test_dice_score_functional(
         )


+@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
+def test_corner_case_no_overlap(average):
+    """Check that the dice score is 0 when there is no overlap between target and preds.
+
+    See issue: https://github.com/Lightning-AI/torchmetrics/issues/2851
+
+    """
+    target = torch.full((4, 4, 128, 128), 0, dtype=torch.int8)
+    preds = torch.full((4, 4, 128, 128), 0, dtype=torch.int8)
+    target[0, 0] = 1
+    preds[0, 0] = 1
+    dice = DiceScore(num_classes=3, average=average, include_background=False)
+    assert dice(preds, target) == 0.0
+
+
+@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None])
+@pytest.mark.parametrize("zero_division", [1.0, 0.0, "warn", "nan"])
+def test_zero_division(zero_division, average):
+    """Test different zero_division values."""
+    target = torch.full((1, 3, 128, 128), 0, dtype=torch.int8)
+    preds = torch.full((1, 3, 128, 128), 0, dtype=torch.int8)
+    target[0, 0] = 1
+    dice = DiceScore(num_classes=3, average=average, zero_division=zero_division)
+    score = dice(preds, target)
+
+    res_dict = {
+        "micro": {1.0: tensor(0.0), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(0.0)},
+        "macro": {1.0: tensor(0.6667), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(float("nan"))},
+        "weighted": {1.0: tensor(0.0), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(float("nan"))},
+        None: {
+            1.0: tensor([0.0, 1.0, 1.0]),
+            0.0: tensor([0.0, 0.0, 0.0]),
+            "warn": tensor([0.0, 0.0, 0.0]),
+            "nan": tensor([0.0, float("nan"), float("nan")]),
+        },
+    }
+
+    assert torch.allclose(score, res_dict[average][zero_division], atol=1e-4, equal_nan=True), (
+        f"Expected {res_dict[average][zero_division]} but got {score}"
+    )
+
+
 @pytest.mark.parametrize("compute_groups", [True, False])
 def test_dice_score_metric_collection(compute_groups: bool, num_batches: int = 4):
     """Test that the metric works within a metric collection with and without compute groups."""
