File size: 3,126 Bytes
0ecb9aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
from torch import nn, Tensor
import math
from typing import List, Optional, Dict, Tuple
class MultiscaleMAE(nn.Module):
    """Weighted sum of L1 losses between two single-channel maps at several pooled scales.

    For every scale ``s`` both inputs are average-pooled with a ``s x s`` window and
    stride ``s`` (scale 1 uses the maps unchanged). The element-wise absolute error is
    summed over the two spatial dimensions, averaged over the remaining dimensions,
    multiplied by the weight of that scale and accumulated into the total loss.

    Args:
        scales: Positive integer pooling factors. They are stored in ascending order.
        min_scale_weight: Lower bound of the automatically generated weights.
        max_scale_weight: Weight given to the largest scale when weights are generated.
        alpha: Decay base in (0, 1) used to generate weights:
            ``min + (max - min) * alpha ** log2(largest_scale / scale)``.
        weights: Optional explicit weights, one per entry of ``scales`` and given in the
            same order as ``scales``. When provided, ``min_scale_weight``,
            ``max_scale_weight`` and ``alpha`` are still validated but not used to
            compute weights.

    Raises:
        AssertionError: If any argument fails the validation described above.
    """

    def __init__(
        self,
        scales: List[int] = [1, 2, 4],  # never mutated here, so the shared default is harmless
        min_scale_weight: float = 0.0,
        max_scale_weight: float = 1.0,
        alpha: float = 0.5,
        weights: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
        assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
        assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
        self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight

        if weights is None:
            # Sort ascending so that the last scale is the largest one.
            scales = sorted(scales)
            # e.g., [1, 2, 4, 8] -> [0.125, 0.25, 0.5, 1] with the default arguments.
            weights = [
                min_scale_weight + (max_scale_weight - min_scale_weight) * alpha ** (math.log2(scales[-1] / s))
                for s in scales
            ]
        else:
            assert len(scales) == len(weights), f"Expected scales and weights to have the same length, got {len(scales)} and {len(weights)}"
            # Sort scales and weights *together*: sorting only the scales would attach
            # each user-supplied weight to the wrong scale whenever the scales were
            # passed unsorted. The key keeps the sort stable for duplicate scales.
            paired = sorted(zip(scales, weights), key=lambda pair: pair[0])
            scales = [scale for scale, _ in paired]
            weights = [weight for _, weight in paired]

        self.scales, self.weights = scales, weights
        for scale, weight in zip(scales, weights):
            pool = nn.AvgPool2d(kernel_size=scale, stride=scale) if scale > 1 else nn.Identity()
            setattr(self, f"pool_{scale}", pool)
            setattr(self, f"weight_{scale}", weight)
            # reduction="none" keeps the per-element errors so forward() can sum spatially.
            setattr(self, f"mae_loss_fn_{scale}", nn.L1Loss(reduction="none"))

    def forward(
        self,
        pred_den_map: Tensor,
        gt_den_map: Tensor,
    ) -> Tuple[Tensor, Dict]:
        """Compute the weighted multi-scale L1 loss.

        Args:
            pred_den_map: Predicted map; must be 4-dimensional with size 1 in dimension 1.
            gt_den_map: Target map with exactly the same shape as ``pred_den_map``.

        Returns:
            A tuple ``(loss, loss_info)`` where ``loss`` is the weighted sum over all
            scales and ``loss_info`` maps ``"mae_loss_{scale}"`` to the detached,
            unweighted loss of that scale.

        Raises:
            AssertionError: If the inputs are not 4-dimensional, do not have a single
                channel, or differ in shape.
        """
        assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
        assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
        assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"

        loss, loss_info = 0, {}
        for scale in self.scales:
            # Look the per-scale pieces up by name so attribute overrides made after
            # construction (e.g. a changed weight_{scale}) are honoured.
            pool = getattr(self, f"pool_{scale}")
            weight = getattr(self, f"weight_{scale}")
            loss_fn = getattr(self, f"mae_loss_fn_{scale}")

            pred_den_map_pool = pool(pred_den_map)
            gt_den_map_pool = pool(gt_den_map)

            # Sum the absolute error over the spatial dims, then average what remains.
            mae_loss_scale = loss_fn(pred_den_map_pool, gt_den_map_pool).sum(dim=(-1, -2)).mean()
            loss += weight * mae_loss_scale
            loss_info[f"mae_loss_{scale}"] = mae_loss_scale.detach()

        return loss, loss_info
|