import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import List, Tuple, Dict
from .dm_loss import DMLoss
from .multiscale_mae import MultiscaleMAE
from .utils import _reshape_density

class DualLoss(nn.Module):
    """Joint classification-regression loss over block-level density maps.

    Combines an optional per-block bin-classification loss ("ce", "mae",
    "mse"), an optional density-regression loss ("dm", "msmae", "mae",
    "mse"), and an always-applied L1 loss on the total predicted count.
    """

    def __init__(
        self,
        input_size: int,
        block_size: int,
        bins: List[Tuple[float, float]],
        bin_centers: List[float],
        cls_loss: str = "ce",
        reg_loss: str = "dm",
        weight_tv: float = 0.01,
        weight_cls: float = 0.1,
        weight_reg: float = 0.1,
        numItermax: int = 100,
        regularization: float = 10.0,
        scales: List[int] = [1, 2, 4],
        min_scale_weight: float = 0.25,
        max_scale_weight: float = 0.75,
        alpha: float = 0.5,
    ) -> None:
        super().__init__()
        # --- Validate bin configuration ---
        assert len(bins) == len(bin_centers) >= 2, f"Expected bins and bin_centers to have at least 2 elements, got {len(bins)} and {len(bin_centers)}"
        assert all(len(b) == 2 for b in bins), f"Expected all bins to be of length 2, got {bins}"
        assert all(b[0] <= c <= b[1] for b, c in zip(bins, bin_centers)), f"Expected each bin center to lie within its bin, got {bins} and {bin_centers}"
        # --- Validate loss selection and weights ---
        assert cls_loss in ["ce", "mae", "mse", "none"], f"Expected cls_loss to be one of ['ce', 'mae', 'mse', 'none'], got {cls_loss}"
        assert reg_loss in ["dm", "msmae", "mae", "mse", "none"], f"Expected reg_loss to be one of ['dm', 'msmae', 'mae', 'mse', 'none'], got {reg_loss}"
        assert not (cls_loss == "none" and reg_loss == "none"), "Expected at least one of cls_loss and reg_loss to be provided"
        assert weight_cls is None or weight_cls >= 0, f"Expected weight_cls to be non-negative, got {weight_cls}"
        assert weight_reg is None or weight_reg >= 0, f"Expected weight_reg to be non-negative, got {weight_reg}"
        assert weight_tv is None or weight_tv >= 0, f"Expected weight_tv to be non-negative, got {weight_tv}"
        assert min_scale_weight is None or max_scale_weight is None or max_scale_weight >= min_scale_weight > 0, f"Expected 0 < min_scale_weight <= max_scale_weight, got {min_scale_weight} and {max_scale_weight}"
        assert alpha is None or 0 < alpha < 1, f"Expected alpha to be strictly between 0 and 1, got {alpha}"
        # DM loss needs Sinkhorn hyperparameters; other regression losses do not.
        if reg_loss == "dm":
            assert numItermax is not None and numItermax > 0, f"Expected numItermax to be a positive integer, got {numItermax}"
            assert regularization is not None and regularization > 0, f"Expected regularization to be a positive float, got {regularization}"
            assert weight_tv is not None and weight_tv >= 0, f"Expected weight_tv to be non-negative, got {weight_tv}"
        else:
            weight_tv, numItermax, regularization = None, None, None
        # Multiscale MAE needs its scale schedule; other regression losses do not.
        if reg_loss == "msmae":
            assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
            assert max_scale_weight >= min_scale_weight > 0, f"Expected 0 < min_scale_weight <= max_scale_weight, got {min_scale_weight} and {max_scale_weight}"
            assert 0 < alpha < 1, f"Expected alpha to be strictly between 0 and 1, got {alpha}"
        else:
            scales = None
            min_scale_weight, max_scale_weight = None, None
            alpha = None
        weight_cls = weight_cls if weight_cls is not None else 0
        weight_reg = weight_reg if weight_reg is not None else 0
        self.input_size, self.block_size = input_size, block_size
        self.num_blocks_h, self.num_blocks_w = input_size // block_size, input_size // block_size
        self.bins, self.bin_centers, self.num_bins = bins, bin_centers, len(bins)
        self.cls_loss, self.reg_loss = cls_loss, reg_loss
        self.weight_cls, self.weight_reg = weight_cls, weight_reg
        self.numItermax, self.regularization = numItermax, regularization
        self.weight_tv = weight_tv
        self.scales = scales
        self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight
        if cls_loss == "ce":
            self.cls_loss_fn = nn.CrossEntropyLoss(reduction="none")
            self.weight_cls = 1.0  # cross-entropy is always applied at full weight
        elif cls_loss == "mae":
            self.cls_loss_fn = nn.L1Loss(reduction="none")
            self.weight_cls = weight_cls
        elif cls_loss == "mse":
            self.cls_loss_fn = nn.MSELoss(reduction="none")
            self.weight_cls = weight_cls
        else:  # cls_loss == "none"
            self.cls_loss_fn = None
            self.weight_cls = 0
        if reg_loss == "dm":
            self.reg_loss_fn = DMLoss(
                input_size=input_size,
                block_size=block_size,
                numItermax=numItermax,
                regularization=regularization,
                weight_ot=weight_reg,
                weight_tv=weight_tv,
                weight_cnt=0,  # the count loss is computed separately in forward()
            )
            self.weight_reg = 1.0  # the regression weight is already applied inside DMLoss via weight_ot
        elif reg_loss == "msmae":
            self.reg_loss_fn = MultiscaleMAE(
                scales=scales,
                weights=None,
                min_scale_weight=min_scale_weight,
                max_scale_weight=max_scale_weight,
                alpha=alpha,
            )
            self.weight_reg = 1.0  # MultiscaleMAE handles its own per-scale weighting
        elif reg_loss == "mae":
            self.reg_loss_fn = nn.L1Loss(reduction="none")
            self.weight_reg = weight_reg
        elif reg_loss == "mse":
            self.reg_loss_fn = nn.MSELoss(reduction="none")
            self.weight_reg = weight_reg
        else:  # reg_loss == "none"
            self.reg_loss_fn = None
            self.weight_reg = 0
        # The count loss (L1 on total counts) is always applied at unit weight.
        self.cnt_loss_fn = nn.L1Loss(reduction="none")

    def _bin_count(self, density_map: Tensor) -> Tensor:
        # Map each block's density value to the index of the bin containing it.
        # When bin edges touch, a value on the shared edge falls into the
        # later (higher-index) bin, since later matches overwrite earlier ones.
        class_map = torch.zeros_like(density_map, dtype=torch.long)
        for idx, (low, high) in enumerate(self.bins):
            mask = (density_map >= low) & (density_map <= high)
            class_map[mask] = idx
        return class_map.squeeze(1)  # remove the channel dimension: (B, 1, H, W) -> (B, H, W)
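    # Worked example (illustrative values, not prescribed by this module):
    # with bins = [(0.0, 0.0), (1e-8, 1.0), (1.0, inf)], block densities
    # [0.0, 0.3, 2.5] map to class indices [0, 1, 2]; a value exactly on a
    # shared edge such as 1.0 lands in the later bin (index 2).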

    def forward(
        self,
        pred_logit_map: Tensor,
        pred_den_map: Tensor,
        gt_den_map: Tensor,
        gt_points: List[Tensor],
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        B = pred_logit_map.shape[0]
        assert pred_logit_map.shape == (B, self.num_bins, self.num_blocks_h, self.num_blocks_w), f"Expected pred_logit_map to have shape {(B, self.num_bins, self.num_blocks_h, self.num_blocks_w)}, got {pred_logit_map.shape}"
        # If the ground-truth density map comes in at full input resolution,
        # pool it into block-level densities first.
        if gt_den_map.shape[-2:] != (self.num_blocks_h, self.num_blocks_w):
            assert gt_den_map.shape[-2:] == (self.input_size, self.input_size), f"Expected gt_den_map to have shape {(B, 1, self.input_size, self.input_size)}, got {gt_den_map.shape}"
            gt_den_map = _reshape_density(gt_den_map, block_size=self.block_size)
        assert pred_den_map.shape == gt_den_map.shape == (B, 1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map and gt_den_map to have shape (B, 1, H, W), got {pred_den_map.shape} and {gt_den_map.shape}"
        assert len(gt_points) == B, f"Expected gt_points to have length {B}, got {len(gt_points)}"
        loss_info = {}
        # --- Classification branch: per-block density bin prediction ---
        if self.weight_cls > 0:
            gt_class_map = self._bin_count(gt_den_map)
            if self.cls_loss == "ce":
                cls_loss = self.cls_loss_fn(pred_logit_map, gt_class_map).sum(dim=(-1, -2)).mean()
                loss_info["cls_ce_loss"] = cls_loss.detach()
            else:  # self.cls_loss in ["mae", "mse"]: compare softmax probabilities to one-hot targets
                gt_prob_map = F.one_hot(gt_class_map, num_classes=self.num_bins).float()  # (B, H, W) -> (B, H, W, N)
                gt_prob_map = gt_prob_map.permute(0, 3, 1, 2)  # (B, H, W, N) -> (B, N, H, W)
                pred_prob_map = pred_logit_map.softmax(dim=1)
                cls_loss = self.cls_loss_fn(pred_prob_map, gt_prob_map).sum(dim=(-1, -2)).mean()
                loss_info[f"cls_{self.cls_loss}_loss"] = cls_loss.detach()
        else:
            cls_loss = 0
        # --- Regression branch: block-level density map prediction ---
        if self.weight_reg > 0:
            if self.reg_loss == "dm":
                reg_loss, reg_loss_info = self.reg_loss_fn(
                    pred_den_map=pred_den_map,
                    gt_den_map=gt_den_map,
                    gt_points=gt_points,
                )
                loss_info.update({f"reg_{k}": v for k, v in reg_loss_info.items()})
            elif self.reg_loss == "msmae":
                reg_loss, reg_loss_info = self.reg_loss_fn(pred_den_map, gt_den_map)
                loss_info.update({f"reg_{k}": v for k, v in reg_loss_info.items()})
            else:  # self.reg_loss in ["mae", "mse"]
                reg_loss = self.reg_loss_fn(pred_den_map, gt_den_map).sum(dim=(-1, -2)).mean()
                loss_info[f"reg_{self.reg_loss}_loss"] = reg_loss.detach()
        else:
            reg_loss = 0
        # --- Count branch: L1 between predicted and ground-truth total counts ---
        gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32, device=pred_den_map.device)
        cnt_loss = self.cnt_loss_fn(pred_den_map.sum(dim=(1, 2, 3)), gt_cnt).mean()
        loss_info["cnt_loss"] = cnt_loss.detach()
        total_loss = self.weight_cls * cls_loss + self.weight_reg * reg_loss + cnt_loss
        loss_info["total_loss"] = total_loss.detach()
        loss_info = dict(sorted(loss_info.items()))  # sort by key for nicer printing
        return total_loss, loss_info
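
# Usage sketch (illustrative only): the shapes, bin edges, and centers below
# are assumptions for demonstration, not values prescribed by this module.
#
#   criterion = DualLoss(
#       input_size=256,
#       block_size=16,
#       bins=[(0.0, 0.0), (1e-8, 1.0), (1.0, float("inf"))],
#       bin_centers=[0.0, 0.5, 2.0],
#   )
#   pred_logit_map = torch.randn(2, 3, 16, 16)  # (B, num_bins, H/block, W/block)
#   pred_den_map = torch.rand(2, 1, 16, 16)     # block-level density prediction
#   gt_den_map = torch.rand(2, 1, 256, 256)     # full-resolution GT, reshaped internally
#   gt_points = [torch.rand(5, 2) * 256, torch.rand(8, 2) * 256]  # one (N_i, 2) tensor per image
#   total_loss, loss_info = criterion(pred_logit_map, pred_den_map, gt_den_map, gt_points)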