import math
from typing import Dict, List, Optional, Tuple

from torch import nn, Tensor


class MultiscaleMAE(nn.Module):
    """Multiscale MAE (L1) loss for density maps.

    Predicted and ground-truth maps are average-pooled at each scale and compared
    with an L1 loss; the per-scale losses are combined with weights that decay
    geometrically (by ``alpha``) from the coarsest scale to the finest.
    """

    def __init__(
        self,
        scales: List[int] = [1, 2, 4],
        min_scale_weight: float = 0.0,
        max_scale_weight: float = 1.0,
        alpha: float = 0.5,
        weights: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
        assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
        assert 0 < alpha < 1, f"Expected alpha to be between 0 and 1, got {alpha}"
        self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight

        scales = sorted(scales)  # sort scales in ascending order so that the last one is the largest
        if weights is None:
            # Weights decay geometrically with the distance (in octaves) from the largest scale,
            # e.g. scales [1, 2, 4, 8] with alpha=0.5, min=0, max=1 give weights [0.125, 0.25, 0.5, 1].
            weights = [min_scale_weight + (max_scale_weight - min_scale_weight) * alpha ** math.log2(scales[-1] / s) for s in scales]

        assert len(scales) == len(weights), f"Expected scales and weights to have the same length, got {len(scales)} and {len(weights)}"
        self.scales, self.weights = scales, weights

        # Register a pooling layer, loss weight, and element-wise L1 loss per scale.
        for idx in range(len(scales)):
            setattr(self, f"pool_{scales[idx]}", nn.AvgPool2d(kernel_size=scales[idx], stride=scales[idx]) if scales[idx] > 1 else nn.Identity())
            setattr(self, f"weight_{scales[idx]}", weights[idx])
            setattr(self, f"mae_loss_fn_{scales[idx]}", nn.L1Loss(reduction="none"))

    def forward(
        self,
        pred_den_map: Tensor,
        gt_den_map: Tensor,
    ) -> Tuple[Tensor, Dict]:
        """Compute the weighted multiscale loss and a per-scale breakdown.

        Both density maps are expected to have shape (B, 1, H, W).
        """
        assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
        assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
        assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"

        loss, loss_info = 0, {}
        for idx in range(len(self.scales)):
            pool = getattr(self, f"pool_{self.scales[idx]}")
            weight = getattr(self, f"weight_{self.scales[idx]}")
            loss_fn = getattr(self, f"mae_loss_fn_{self.scales[idx]}")

            # Downsample both maps to the current scale, then take the per-image
            # spatial sum of absolute errors and average over the batch.
            pred_den_map_pool = pool(pred_den_map)
            gt_den_map_pool = pool(gt_den_map)

            mae_loss_scale = loss_fn(pred_den_map_pool, gt_den_map_pool).sum(dim=(-1, -2)).mean()
            loss += weight * mae_loss_scale
            loss_info[f"mae_loss_{self.scales[idx]}"] = mae_loss_scale.detach()

        return loss, loss_info
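

# Illustrative usage sketch (not part of the original module): the tensor shapes and
# random inputs below are assumptions chosen only to demonstrate the expected interface.
if __name__ == "__main__":
    import torch

    criterion = MultiscaleMAE(scales=[1, 2, 4], alpha=0.5)
    pred = torch.rand(2, 1, 64, 64)  # hypothetical predicted density maps, shape (B, 1, H, W)
    gt = torch.rand(2, 1, 64, 64)    # hypothetical ground-truth density maps, same shape
    loss, loss_info = criterion(pred, gt)
    print(loss.item(), {k: v.item() for k, v in loss_info.items()})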