|
|
from torch import nn, Tensor |
|
|
import math |
|
|
from typing import List, Optional, Dict, Tuple |
|
|
|
|
|
|
|
|
class MultiscaleMAE(nn.Module):
    """Weighted sum of mean-absolute-error losses evaluated at several pooling scales.

    For every scale ``s`` the prediction and the ground truth are passed
    through the same pooling layer (``nn.AvgPool2d`` with kernel and stride
    ``s``, or ``nn.Identity`` when ``s == 1``). The element-wise L1 error of
    the pooled maps is summed over the last two dimensions, averaged over
    the remaining ones, multiplied by the weight of that scale and added to
    the total loss.

    Attributes set by the constructor:
        scales: the scales in ascending order.
        weights: the weights, in the same order as ``scales``.
        pool_<s>, weight_<s>, mae_loss_fn_<s>: per-scale pooling layer,
            weight and ``nn.L1Loss(reduction="none")`` instance.
    """

    def __init__(
        self,
        scales: Optional[List[int]] = None,
        min_scale_weight: float = 0.0,
        max_scale_weight: float = 1.0,
        alpha: float = 0.5,
        weights: Optional[List[float]] = None,
    ) -> None:
        """Validate the configuration and register one pooling layer per scale.

        Args:
            scales: Positive integer pooling factors (list or tuple).
                ``None`` selects the default ``[1, 2, 4]``.
            min_scale_weight: Lower bound of the automatically derived
                weights; must be non-negative.
            max_scale_weight: Weight given to the largest scale when the
                weights are derived automatically; must not be smaller than
                ``min_scale_weight``.
            alpha: Decay base, strictly between 0 and 1. The derived weight
                of scale ``s`` is ``min_scale_weight + (max_scale_weight -
                min_scale_weight) * alpha ** log2(largest_scale / s)``.
            weights: Optional explicit weights. ``weights[i]`` belongs to
                ``scales[i]`` in the order the caller passed them.

        Raises:
            AssertionError: If any of the checks on ``scales``, the weight
                bounds, ``alpha`` or the length of ``weights`` fails.
        """
        super().__init__()

        # A None sentinel stands in for the default so that no mutable list
        # object is shared between calls through the signature.
        if scales is None:
            scales = [1, 2, 4]

        assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
        assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
        assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
        self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight

        if weights is None:
            scales = sorted(scales)
            largest_scale = scales[-1]
            weight_range = max_scale_weight - min_scale_weight
            weights = [
                min_scale_weight + weight_range * alpha ** (math.log2(largest_scale / s))
                for s in scales
            ]
        else:
            assert len(scales) == len(weights), f"Expected scales and weights to have the same length, got {len(scales)} and {len(weights)}"
            # Sort scales and weights together: sorting the scales alone
            # would attach each explicit weight to the wrong scale whenever
            # the caller passes the scales in non-ascending order.
            paired = sorted(zip(scales, weights), key=lambda pair: pair[0])
            scales = [scale for scale, _ in paired]
            weights = [weight for _, weight in paired]

        self.scales, self.weights = scales, weights

        for scale, weight in zip(scales, weights):
            # Scale 1 needs no pooling, so an identity layer is used instead.
            pool = nn.AvgPool2d(kernel_size=scale, stride=scale) if scale > 1 else nn.Identity()
            setattr(self, f"pool_{scale}", pool)
            setattr(self, f"weight_{scale}", weight)
            setattr(self, f"mae_loss_fn_{scale}", nn.L1Loss(reduction="none"))

    def forward(
        self,
        pred_den_map: Tensor,
        gt_den_map: Tensor,
    ) -> Tuple[Tensor, Dict]:
        """Compute the weighted multi-scale MAE between two maps.

        Args:
            pred_den_map: 4-D tensor whose dimension 1 has size 1.
            gt_den_map: Tensor with exactly the same shape as
                ``pred_den_map``.

        Returns:
            A tuple ``(loss, loss_info)``. ``loss`` is the weighted sum of
            the per-scale losses. ``loss_info`` maps ``"mae_loss_<s>"`` to
            the detached, unweighted loss of scale ``s``.

        Raises:
            AssertionError: If either input is not 4-D, does not have a
                single channel, or the two shapes differ.
        """
        assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
        assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
        assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"

        loss, loss_info = 0, {}
        for scale in self.scales:
            pool = getattr(self, f"pool_{scale}")
            weight = getattr(self, f"weight_{scale}")
            loss_fn = getattr(self, f"mae_loss_fn_{scale}")

            pred_den_map_pool = pool(pred_den_map)
            gt_den_map_pool = pool(gt_den_map)

            # Sum the absolute error over the last two dimensions, then
            # average the per-map totals over the remaining dimensions.
            mae_loss_scale = loss_fn(pred_den_map_pool, gt_den_map_pool).sum(dim=(-1, -2)).mean()
            loss += weight * mae_loss_scale
            loss_info[f"mae_loss_{scale}"] = mae_loss_scale.detach()

        return loss, loss_info
|
|
|