Yiming-M committed on
Commit 0ecb9aa · verified · 1 Parent(s): ffd9437

2025-07-31 15:53 🐣

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. README.md +75 -3
  2. configs/bin_config.json +50 -0
  3. configs/nwpu.yaml +34 -0
  4. configs/qnrf.yaml +33 -0
  5. configs/sha.yaml +33 -0
  6. configs/shb.yaml +33 -0
  7. count.py +253 -0
  8. count.sh +5 -0
  9. counts/jhu.json +425 -0
  10. counts/jhu_max.json +74 -0
  11. counts/nwpu.json +761 -0
  12. counts/nwpu_max.json +74 -0
  13. counts/qnrf.json +569 -0
  14. counts/qnrf_max.json +74 -0
  15. counts/sha.json +578 -0
  16. counts/sha_max.json +74 -0
  17. counts/shb.json +313 -0
  18. counts/shb_max.json +74 -0
  19. datasets/__init__.py +12 -0
  20. datasets/crowd.py +309 -0
  21. datasets/transforms.py +262 -0
  22. datasets/utils.py +63 -0
  23. efficiency.py +163 -0
  24. evaluate.py +84 -0
  25. losses/__init__.py +7 -0
  26. losses/bregman_pytorch.py +70 -0
  27. losses/dm_loss.py +142 -0
  28. losses/dual_loss.py +175 -0
  29. losses/loss.py +204 -0
  30. losses/multiscale_mae.py +55 -0
  31. losses/poisson_nll.py +46 -0
  32. losses/utils.py +19 -0
  33. losses/zero_inflated_poisson_nll.py +96 -0
  34. models/__init__.py +155 -0
  35. models/clip_ebc/__init__.py +7 -0
  36. models/clip_ebc/__pycache__/__init__.cpython-312.pyc +0 -0
  37. models/clip_ebc/__pycache__/convnext.cpython-312.pyc +0 -0
  38. models/clip_ebc/__pycache__/mobileclip.cpython-312.pyc +0 -0
  39. models/clip_ebc/__pycache__/model.cpython-312.pyc +0 -0
  40. models/clip_ebc/__pycache__/resnet.cpython-312.pyc +0 -0
  41. models/clip_ebc/__pycache__/utils.cpython-312.pyc +0 -0
  42. models/clip_ebc/__pycache__/vit.cpython-312.pyc +0 -0
  43. models/clip_ebc/__pycache__/vit_siglip.cpython-312.pyc +0 -0
  44. models/clip_ebc/convnext.py +199 -0
  45. models/clip_ebc/mobileclip.py +197 -0
  46. models/clip_ebc/model.py +272 -0
  47. models/clip_ebc/resnet.py +236 -0
  48. models/clip_ebc/utils.py +137 -0
  49. models/clip_ebc/vit.py +372 -0
  50. models/ebc/__init__.py +3 -0
README.md CHANGED
@@ -1,3 +1,75 @@
1
- ---
2
- license: mit
3
- ---
1
+ # EBC-ZIP
2
+
3
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ebc-zip-improving-blockwise-crowd-counting/crowd-counting-on-shanghaitech-a)](https://paperswithcode.com/sota/crowd-counting-on-shanghaitech-a?p=ebc-zip-improving-blockwise-crowd-counting)
4
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ebc-zip-improving-blockwise-crowd-counting/crowd-counting-on-shanghaitech-b)](https://paperswithcode.com/sota/crowd-counting-on-shanghaitech-b?p=ebc-zip-improving-blockwise-crowd-counting)
5
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ebc-zip-improving-blockwise-crowd-counting/crowd-counting-on-ucf-qnrf)](https://paperswithcode.com/sota/crowd-counting-on-ucf-qnrf?p=ebc-zip-improving-blockwise-crowd-counting)
6
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/ebc-zip-improving-blockwise-crowd-counting/crowd-counting-on-nwpu-crowd-val)](https://paperswithcode.com/sota/crowd-counting-on-nwpu-crowd-val?p=ebc-zip-improving-blockwise-crowd-counting)
7
+
8
+ The official implementation of the paper [*ZIP: Scalable Crowd Counting via Zero-Inflated Poisson Modeling*](https://arxiv.org/pdf/2506.19955).
9
+
10
+ ## Results
11
+
12
+ | **Variants** | **Size (M)** | **GFLOPS (on HD)** | **SHA (MAE)** | **SHA (RMSE)** | **SHA (NAE, %)** | **SHB (MAE)** | **SHB (RMSE)** | **SHB (NAE, %)** | **QNRF (MAE)** | **QNRF (RMSE)** | **QNRF (NAE, %)** |
13
+ |--------------|--------------|--------------------|---------------|----------------|------------------|---------------|----------------|------------------|----------------|-----------------|-------------------|
14
+ | -P (Pico) | 0.81 | 6.46 | 71.18 | 109.60 | 16.69 | 8.23 | 12.62 | 6.98 | 96.29 | 161.82 | 14.40 |
15
+ | -N (Nano) | 3.36 | 24.73 | 58.86 | 94.63 | 14.15 | 7.74 | 12.14 | 6.33 | 86.46 | 147.64 | 12.60 |
16
+ | -T (Tiny) | 10.53 | 61.39 | 56.36 | 86.09 | 13.26 | 6.67 | 9.90 | 5.52 | 76.02 | 129.40 | 11.10 |
17
+ | -S (Small) | 33.60 | 242.43 | 55.17 | 88.99 | 11.97 | 5.83 | 9.21 | 4.58 | 73.32 | 125.09 | 10.40 |
18
+ | -B (Base) | 105.60 | 800.99 | 47.81 | 75.04 | 11.06 | 5.51 | 8.63 | 4.48 | 69.46 | 121.88 | 10.18 |
19
+
20
+ ## Step 1: Install Dependencies
21
+
22
+ ```bash
23
+ pip install -r requirements.txt
24
+ ```
25
+
26
+ ## Step 2: Download Processed Datasets
27
+
28
+ - **ShanghaiTech A**: [sha.zip](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/sha.zip)
29
+ - **ShanghaiTech B**: [shb.zip](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/shb.zip)
30
+ - **UCF-QNRF**: [qnrf.zip](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/qnrf.zip), [qnrf.z01](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/qnrf.z01)
31
+ - **NWPU-Crowd**: [nwpu.zip](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.zip), [nwpu.z01](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z01), [nwpu.z02](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z02), [nwpu.z03](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z03), [nwpu.z04](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z04), [nwpu.z05](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z05), [nwpu.z06](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z06), [nwpu.z07](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z07), [nwpu.z08](https://github.com/Yiming-M/EBC-ZIP/releases/download/dataset/nwpu.z08)
32
+
33
+ To unzip the split `.zip` archives, 7-Zip is recommended. You can install 7-Zip and extract a dataset with the following commands:
34
+
35
+ ```bash
36
+ sudo apt update
37
+ sudo apt install p7zip-full
38
+
39
+ 7z x dataset.zip
40
+ ```
41
+
42
+ ## Step 3: Run Training
43
+
44
+ Add the training command to `run.sh` and execute it:
45
+
46
+ ```bash
47
+ sh run.sh
48
+ ```
49
+
50
+ To use the zero-inflated loss, set either `--reg_loss` or `--aux_loss` to `zipnll`. For example, `--reg_loss zipnll` uses the zero-inflated loss for regression.
51
+
52
+ You can add an auxiliary loss to improve performance. For example, use the predefined multi-scale MAE loss by setting `--aux_loss msmae` and `--scales 1 2 4`.
53
+
54
+ The DMCount loss can also be used together with the zero-inflated loss. For example, you can set `--reg_loss zipnll --aux_loss dmcount` to use both losses.
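For intuition about what the `zipnll` option optimizes, here is a minimal, self-contained sketch of the textbook zero-inflated Poisson negative log-likelihood. It is written from the standard formula, not taken from `losses/zero_inflated_poisson_nll.py`; the function name and tensor shapes are illustrative only.

```python
import torch

def zip_nll(pi: torch.Tensor, lam: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Textbook zero-inflated Poisson NLL (illustrative sketch, not the repo's loss).

    pi  -- per-block probability of a structural zero (empty block), in (0, 1)
    lam -- per-block Poisson rate for occupied blocks, > 0
    y   -- observed integer count per block
    """
    # P(y = 0) = pi + (1 - pi) * exp(-lam)
    log_p_zero = torch.log(pi + (1.0 - pi) * torch.exp(-lam) + eps)
    # log P(y = k) = log(1 - pi) + k * log(lam) - lam - log(k!)  for k >= 1
    log_p_pos = torch.log(1.0 - pi + eps) + y * torch.log(lam + eps) - lam - torch.lgamma(y + 1.0)
    log_prob = torch.where(y == 0, log_p_zero, log_p_pos)
    return -log_prob.mean()

# Toy usage: four blocks, most of them empty -- the typical situation for crowd density maps.
pi = torch.full((4,), 0.6)
lam = torch.full((4,), 2.0)
y = torch.tensor([0.0, 0.0, 1.0, 3.0])
print(zip_nll(pi, lam, y))
```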
55
+
56
+
57
+ ## Step 4: Test the Model
58
+
59
+ Use `test.py` or `test.sh` to test the model. You can specify the dataset, weight path, input size, and other parameters.
60
+
61
+ To generate the predicted counts on NWPU-Crowd Test, you need to use `test_nwpu.py` instead.
62
+
63
+ To visualize the results, use the `notebooks/model.ipynb` notebook.
64
+
65
+ Trained weights are also provided:
66
+ - [**ShanghaiTech A**](https://github.com/Yiming-M/EBC-ZIP/releases/tag/weights_sha)
67
+ - [**ShanghaiTech B**](https://github.com/Yiming-M/EBC-ZIP/releases/tag/weights_shb)
68
+ - [**UCF-QNRF**](https://github.com/Yiming-M/EBC-ZIP/releases/tag/weights_qnrf)
69
+ - [**NWPU-Crowd**](https://github.com/Yiming-M/EBC-ZIP/releases/tag/weights_nwpu)
70
+
71
+ Make sure to use the processed datasets and the exact commands predefined in `test.sh` to reproduce the reported results.
72
+
73
+ ## Step 5: Visualize the Results
74
+
75
+ Use the `notebooks/model.ipynb` notebook to visualize the results.
configs/bin_config.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "shb": {
3
+ "8": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]],
4
+ "16": [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, "inf"]],
5
+ "32": [
6
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
7
+ [10, 10], [11, 11], [12, 12], [13, 13], [14, 14],
8
+ [15, 16], [17, 18], [19, 20],
9
+ [21, 23], [24, "inf"]
10
+ ]
11
+ },
12
+ "sha": {
13
+ "8": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]],
14
+ "16": [
15
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
16
+ [10, 10], [11, 12], [13, 14], [15, "inf"]
17
+ ],
18
+ "32": [
19
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
20
+ [10, 10], [11, 11], [12, 12], [13, 13], [14, 14], [15, 15], [16, 16], [17, 17], [18, 18], [19, 19],
21
+ [20, 21], [22, 23], [24, 25], [26, 27], [28, 29],
22
+ [30, 32], [33, 35], [36, 38], [39, 41],
23
+ [42, 45], [46, "inf"]
24
+ ]
25
+ },
26
+ "qnrf": {
27
+ "8": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]],
28
+ "16": [
29
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
30
+ [10, 10], [11, 12], [13, "inf"]
31
+ ],
32
+ "32": [
33
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
34
+ [10, 10], [11, 12], [13, 14], [15, 16], [17, 18], [19, 20],
35
+ [21, 23], [24, 26], [27, 29], [30, 33], [34, "inf"]
36
+ ]
37
+ },
38
+ "nwpu": {
39
+ "8": [[0, 0], [1, 1], [2, 2], [3, 3], [4, "inf"]],
40
+ "16": [
41
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
42
+ [10, "inf"]
43
+ ],
44
+ "32": [
45
+ [0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7], [8, 8], [9, 9],
46
+ [10, 11], [12, 13], [14, 15], [16, 17], [18, 19],
47
+ [20, 22], [23, 25], [26, 28], [29, "inf"]
48
+ ]
49
+ }
50
+ }
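The intervals above appear to define the classification bins for blockwise counts at each block size, with `"inf"` marking an open-ended top bin. Below is a hypothetical helper (not the repo's own loader; the function names are made up) showing how such a file could be read and how a block count would map to a bin index.

```python
import json

def load_bins(path: str, dataset: str, block_size: int):
    """Read the [low, high] count intervals for one dataset/block size (sketch only)."""
    with open(path) as f:
        cfg = json.load(f)
    bins = cfg[dataset][str(block_size)]
    # The JSON stores the open upper bound as the string "inf"; convert it for comparisons.
    return [(lo, float("inf") if hi == "inf" else hi) for lo, hi in bins]

def bin_index(count: int, bins) -> int:
    """Return the index of the interval that contains an integer block count."""
    for i, (lo, hi) in enumerate(bins):
        if lo <= count <= hi:
            return i
    raise ValueError(f"count {count} is not covered by any bin")

bins = load_bins("configs/bin_config.json", "sha", 16)
print(bins[-1])             # the open-ended top bin, (15, inf)
print(bin_index(12, bins))  # counts 11 and 12 share one bin at this block size
```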
configs/nwpu.yaml ADDED
@@ -0,0 +1,34 @@
1
+ metadata:
2
+ name: NWPU
3
+ description: Training configuration on the NWPU dataset.
4
+
5
+ input_size: 672
6
+ block_size: 16
7
+ batch_size: 8
8
+ num_crops: 1
9
+
10
+ aug_min_scale: 0.75
11
+ aug_max_scale: 2.0
12
+ aug_brightness: 0.2
13
+ aug_contrast: 0.2
14
+ aug_saturation: 0.15
15
+ aug_hue: 0.0
16
+ aug_kernel_size: 5
17
+ aug_blur_prob: 0.2
18
+ aug_saltiness: 0.001
19
+ aug_spiciness: 0.001
20
+
21
+ lr: 0.0001
22
+ vpt_lr: 0.0001
23
+ adapter_lr: 0.0001
24
+ lora_lr: 0.0001
25
+ backbone_lr: 0.0001
26
+
27
+ weight_decay: 0.0001
28
+ vpt_weight_decay: 0.0001
29
+ adapter_weight_decay: 0.0001
30
+ lora_weight_decay: 0.0001
31
+ backbone_weight_decay: 0.0001
32
+
33
+ eval_freq: 1.0
34
+ eval_start: 100
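This YAML file (and the similar ones below) only declares hyperparameters. A minimal sketch for inspecting it, assuming PyYAML is available (the repo's own config loading may differ):

```python
import yaml

with open("configs/nwpu.yaml") as f:
    cfg = yaml.safe_load(f)

# Crop size, ground-truth block size, and batch size used for NWPU training.
print(cfg["metadata"]["name"])                                  # NWPU
print(cfg["input_size"], cfg["block_size"], cfg["batch_size"])  # 672 16 8
```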
configs/qnrf.yaml ADDED
@@ -0,0 +1,33 @@
1
+ metadata:
2
+ name: qnrf
3
+ description: Training configuration on the UCF-QNRF dataset.
4
+
5
+ input_size: 672
6
+ block_size: 32
7
+ batch_size: 8
8
+ num_crops: 1
9
+
10
+ aug_min_scale: 0.75
11
+ aug_max_scale: 2.0
12
+ aug_brightness: 0.15
13
+ aug_contrast: 0.15
14
+ aug_saturation: 0.1
15
+ aug_hue: 0.0
16
+ aug_blur_prob: 0.0
17
+ aug_saltiness: 0.001
18
+ aug_spiciness: 0.001
19
+
20
+ lr: 0.0001
21
+ vpt_lr: 0.0001
22
+ adapter_lr: 0.0001
23
+ lora_lr: 0.0001
24
+ backbone_lr: 0.0001
25
+
26
+ weight_decay: 0.0001
27
+ vpt_weight_decay: 0.0001
28
+ adapter_weight_decay: 0.0001
29
+ lora_weight_decay: 0.0001
30
+ backbone_weight_decay: 0.0001
31
+
32
+ eval_freq: 0.5
33
+ eval_start: 150
configs/sha.yaml ADDED
@@ -0,0 +1,33 @@
1
+ metadata:
2
+ name: sha
3
+ description: Training configuration on the ShanghaiTech A dataset.
4
+
5
+ input_size: 448
6
+ block_size: 16
7
+ batch_size: 8
8
+ num_crops: 1
9
+
10
+ aug_min_scale: 0.75
11
+ aug_max_scale: 2.0
12
+ aug_brightness: 0.15
13
+ aug_contrast: 0.15
14
+ aug_saturation: 0.1
15
+ aug_hue: 0.0
16
+ aug_blur_prob: 0.0
17
+ aug_saltiness: 0.001
18
+ aug_spiciness: 0.001
19
+
20
+ lr: 0.0001
21
+ vpt_lr: 0.0001
22
+ adapter_lr: 0.0001
23
+ lora_lr: 0.0001
24
+ backbone_lr: 0.0001
25
+
26
+ weight_decay: 0.0001
27
+ vpt_weight_decay: 0.0001
28
+ adapter_weight_decay: 0.0001
29
+ lora_weight_decay: 0.0001
30
+ backbone_weight_decay: 0.0001
31
+
32
+ eval_freq: 0.25
33
+ eval_start: 100
configs/shb.yaml ADDED
@@ -0,0 +1,33 @@
1
+ metadata:
2
+ name: shb
3
+ description: Training configuration on the ShanghaiTech B dataset.
4
+
5
+ input_size: 448
6
+ block_size: 16
7
+ batch_size: 8
8
+ num_crops: 1
9
+
10
+ aug_min_scale: 0.75
11
+ aug_max_scale: 2.5
12
+ aug_brightness: 0.15
13
+ aug_contrast: 0.15
14
+ aug_saturation: 0.1
15
+ aug_hue: 0.0
16
+ aug_blur_prob: 0.0
17
+ aug_saltiness: 0.001
18
+ aug_spiciness: 0.001
19
+
20
+ lr: 0.0001
21
+ vpt_lr: 0.0001
22
+ adapter_lr: 0.0001
23
+ lora_lr: 0.0001
24
+ backbone_lr: 0.0001
25
+
26
+ weight_decay: 0.0001
27
+ vpt_weight_decay: 0.0001
28
+ adapter_weight_decay: 0.0001
29
+ lora_weight_decay: 0.0001
30
+ backbone_weight_decay: 0.0001
31
+
32
+ eval_freq: 0.25
33
+ eval_start: 150
count.py ADDED
@@ -0,0 +1,253 @@
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ import os, json
5
+ from tqdm import tqdm
6
+ from argparse import ArgumentParser
7
+ from typing import Dict
8
+
9
+ import datasets
10
+
11
+
12
+ class SumPool2d(nn.Module):
13
+ def __init__(self, kernel_size: int, stride: int):
14
+ super(SumPool2d, self).__init__()
15
+ self.kernel_size = kernel_size
16
+ self.stride = stride
17
+ self.sum_pool = nn.AvgPool2d(kernel_size, stride, divisor_override=1)
18
+
19
+ def forward(self, x):
20
+ return self.sum_pool(x)
21
+
22
+
23
+ def _update_dict(d: Dict, keys: np.ndarray, values: np.ndarray) -> Dict:
24
+ keys = keys.tolist() if isinstance(keys, np.ndarray) else keys
25
+ values = values.tolist() if isinstance(values, np.ndarray) else values
26
+ for k, v in zip(keys, values):
27
+ d[k] = d.get(k, 0) + v
28
+
29
+ return d
30
+
31
+
32
+ def _get_counts(
33
+ dataset_name: str,
34
+ device: torch.device,
35
+ ) -> None:
36
+ filter_4 = SumPool2d(4, 1).to(device)
37
+ filter_7 = SumPool2d(7, 1).to(device)
38
+ filter_8 = SumPool2d(8, 1).to(device)
39
+ filter_14 = SumPool2d(14, 1).to(device)
40
+ filter_16 = SumPool2d(16, 1).to(device)
41
+ filter_28 = SumPool2d(28, 1).to(device)
42
+ filter_32 = SumPool2d(32, 1).to(device)
43
+ filter_56 = SumPool2d(56, 1).to(device)
44
+ filter_64 = SumPool2d(64, 1).to(device)
45
+ counts_1, counts_4, counts_7, counts_8 = {}, {}, {}, {}
46
+ counts_14, counts_16 = {}, {}
47
+ counts_28, counts_32 = {}, {}
48
+ counts_56, counts_64 = {}, {}
49
+
50
+ max_counts_4 = {"max": 0., "name": None, "x": None, "y": None}
51
+ max_counts_7 = {"max": 0., "name": None, "x": None, "y": None}
52
+ max_counts_8 = {"max": 0., "name": None, "x": None, "y": None}
53
+ max_counts_14 = {"max": 0., "name": None, "x": None, "y": None}
54
+ max_counts_16 = {"max": 0., "name": None, "x": None, "y": None}
55
+ max_counts_28 = {"max": 0., "name": None, "x": None, "y": None}
56
+ max_counts_32 = {"max": 0., "name": None, "x": None, "y": None}
57
+ max_counts_56 = {"max": 0., "name": None, "x": None, "y": None}
58
+ max_counts_64 = {"max": 0., "name": None, "x": None, "y": None}
59
+
60
+ counts_dir = os.path.join(os.getcwd(), "counts")
61
+ os.makedirs(counts_dir, exist_ok=True)
62
+
63
+ dataset = datasets.Crowd(dataset=dataset_name, split="train", transforms=None, return_filename=True)
64
+ print(f"Counting {dataset_name} dataset")
65
+
66
+ for i in tqdm(range(len(dataset))):
67
+ _, _, density, img_name = dataset[i]
68
+ density_np = density.cpu().numpy().astype(int)
69
+ uniques_, counts_ = np.unique(density_np, return_counts=True)
70
+ counts_1 = _update_dict(counts_1, uniques_, counts_)
71
+
72
+ density = density.to(device)  # Move the density map to the compute device
73
+ window_4, window_7, window_8 = filter_4(density), filter_7(density), filter_8(density)
74
+ window_14, window_16 = filter_14(density), filter_16(density)
75
+ window_28, window_32 = filter_28(density), filter_32(density)
76
+ window_56, window_64 = filter_56(density), filter_64(density)
77
+
78
+ window_4, window_7, window_8 = torch.round(window_4).int(), torch.round(window_7).int(), torch.round(window_8).int()
79
+ window_14, window_16 = torch.round(window_14).int(), torch.round(window_16).int()
80
+ window_28, window_32 = torch.round(window_28).int(), torch.round(window_32).int()
81
+ window_56, window_64 = torch.round(window_56).int(), torch.round(window_64).int()
82
+
83
+ window_4, window_7, window_8 = torch.squeeze(window_4), torch.squeeze(window_7), torch.squeeze(window_8)
84
+ window_14, window_16 = torch.squeeze(window_14), torch.squeeze(window_16)
85
+ window_28, window_32 = torch.squeeze(window_28), torch.squeeze(window_32)
86
+ window_56, window_64 = torch.squeeze(window_56), torch.squeeze(window_64)
87
+
88
+ if window_4.max().item() > max_counts_4["max"]:
89
+ max_counts_4["max"] = window_4.max().item()
90
+ max_counts_4["name"] = img_name
91
+ x, y = torch.where(window_4 == window_4.max())
92
+ x, y = x[0].item(), y[0].item()
93
+ max_counts_4["x"] = x
94
+ max_counts_4["y"] = y
95
+
96
+ if window_7.max().item() > max_counts_7["max"]:
97
+ max_counts_7["max"] = window_7.max().item()
98
+ max_counts_7["name"] = img_name
99
+ x, y = torch.where(window_7 == window_7.max())
100
+ x, y = x[0].item(), y[0].item()
101
+ max_counts_7["x"] = x
102
+ max_counts_7["y"] = y
103
+
104
+ if window_8.max().item() > max_counts_8["max"]:
105
+ max_counts_8["max"] = window_8.max().item()
106
+ max_counts_8["name"] = img_name
107
+ x, y = torch.where(window_8 == window_8.max())
108
+ x, y = x[0].item(), y[0].item()
109
+ max_counts_8["x"] = x
110
+ max_counts_8["y"] = y
111
+
112
+ if window_14.max().item() > max_counts_14["max"]:
113
+ max_counts_14["max"] = window_14.max().item()
114
+ max_counts_14["name"] = img_name
115
+ x, y = torch.where(window_14 == window_14.max())
116
+ x, y = x[0].item(), y[0].item()
117
+ max_counts_14["x"] = x
118
+ max_counts_14["y"] = y
119
+
120
+ if window_16.max().item() > max_counts_16["max"]:
121
+ max_counts_16["max"] = window_16.max().item()
122
+ max_counts_16["name"] = img_name
123
+ x, y = torch.where(window_16 == window_16.max())
124
+ x, y = x[0].item(), y[0].item()
125
+ max_counts_16["x"] = x
126
+ max_counts_16["y"] = y
127
+
128
+ if window_28.max().item() > max_counts_28["max"]:
129
+ max_counts_28["max"] = window_28.max().item()
130
+ max_counts_28["name"] = img_name
131
+ x, y = torch.where(window_28 == window_28.max())
132
+ x, y = x[0].item(), y[0].item()
133
+ max_counts_28["x"] = x
134
+ max_counts_28["y"] = y
135
+
136
+ if window_32.max().item() > max_counts_32["max"]:
137
+ max_counts_32["max"] = window_32.max().item()
138
+ max_counts_32["name"] = img_name
139
+ x, y = torch.where(window_32 == window_32.max())
140
+ x, y = x[0].item(), y[0].item()
141
+ max_counts_32["x"] = x
142
+ max_counts_32["y"] = y
143
+
144
+ if window_56.max().item() > max_counts_56["max"]:
145
+ max_counts_56["max"] = window_56.max().item()
146
+ max_counts_56["name"] = img_name
147
+ x, y = torch.where(window_56 == window_56.max())
148
+ x, y = x[0].item(), y[0].item()
149
+ max_counts_56["x"] = x
150
+ max_counts_56["y"] = y
151
+
152
+ if window_64.max().item() > max_counts_64["max"]:
153
+ max_counts_64["max"] = window_64.max().item()
154
+ max_counts_64["name"] = img_name
155
+ x, y = torch.where(window_64 == window_64.max())
156
+ x, y = x[0].item(), y[0].item()
157
+ max_counts_64["x"] = x
158
+ max_counts_64["y"] = y
159
+
160
+ window_4 = window_4.view(-1).cpu().numpy().astype(int)
161
+ window_7 = window_7.view(-1).cpu().numpy().astype(int)
162
+ window_8 = window_8.view(-1).cpu().numpy().astype(int)
163
+ window_14 = window_14.view(-1).cpu().numpy().astype(int)
164
+ window_16 = window_16.view(-1).cpu().numpy().astype(int)
165
+ window_28 = window_28.view(-1).cpu().numpy().astype(int)
166
+ window_32 = window_32.view(-1).cpu().numpy().astype(int)
167
+ window_56 = window_56.view(-1).cpu().numpy().astype(int)
168
+ window_64 = window_64.view(-1).cpu().numpy().astype(int)
169
+ #.view(-1).cpu().numpy().astype(int)
170
+
171
+ uniques_, counts_ = np.unique(window_4, return_counts=True)
172
+ counts_4 = _update_dict(counts_4, uniques_, counts_)
173
+
174
+ uniques_, counts_ = np.unique(window_7, return_counts=True)
175
+ counts_7 = _update_dict(counts_7, uniques_, counts_)
176
+
177
+ uniques_, counts_ = np.unique(window_8, return_counts=True)
178
+ counts_8 = _update_dict(counts_8, uniques_, counts_)
179
+
180
+ uniques_, counts_ = np.unique(window_14, return_counts=True)
181
+ counts_14 = _update_dict(counts_14, uniques_, counts_)
182
+
183
+ uniques_, counts_ = np.unique(window_16, return_counts=True)
184
+ counts_16 = _update_dict(counts_16, uniques_, counts_)
185
+
186
+ uniques_, counts_ = np.unique(window_28, return_counts=True)
187
+ counts_28 = _update_dict(counts_28, uniques_, counts_)
188
+
189
+ uniques_, counts_ = np.unique(window_32, return_counts=True)
190
+ counts_32 = _update_dict(counts_32, uniques_, counts_)
191
+
192
+ uniques_, counts_ = np.unique(window_56, return_counts=True)
193
+ counts_56 = _update_dict(counts_56, uniques_, counts_)
194
+
195
+ uniques_, counts_ = np.unique(window_64, return_counts=True)
196
+ counts_64 = _update_dict(counts_64, uniques_, counts_)
197
+
198
+ counts = {
199
+ 1: counts_1,
200
+ 4: counts_4,
201
+ 7: counts_7,
202
+ 8: counts_8,
203
+ 14: counts_14,
204
+ 16: counts_16,
205
+ 28: counts_28,
206
+ 32: counts_32,
207
+ 56: counts_56,
208
+ 64: counts_64
209
+ }
210
+
211
+ max_counts = {
212
+ 4: max_counts_4,
213
+ 7: max_counts_7,
214
+ 8: max_counts_8,
215
+ 14: max_counts_14,
216
+ 16: max_counts_16,
217
+ 28: max_counts_28,
218
+ 32: max_counts_32,
219
+ 56: max_counts_56,
220
+ 64: max_counts_64
221
+ }
222
+
223
+ with open(os.path.join(counts_dir, f"{dataset_name}.json"), "w") as f:
224
+ json.dump(counts, f)
225
+
226
+ with open(os.path.join(counts_dir, f"{dataset_name}_max.json"), "w") as f:
227
+ json.dump(max_counts, f)
228
+
229
+
230
+ def parse_args():
231
+ parser = ArgumentParser(description="Get local counts of the dataset")
232
+ parser.add_argument(
233
+ "--dataset",
234
+ type=str,
235
+ choices=["nwpu", "ucf_qnrf", "shanghaitech_a", "shanghaitech_b"],
236
+ required=True,
237
+ help="The dataset to use."
238
+ )
239
+ parser.add_argument(
240
+ "--device",
241
+ type=str,
242
+ default="cuda",
243
+ help="The device to use."
244
+ )
245
+ args = parser.parse_args()
246
+ return args
247
+
248
+
249
+ if __name__ == "__main__":
250
+ args = parse_args()
251
+ args.dataset = datasets.standardize_dataset_name(args.dataset)
252
+ args.device = torch.device(args.device)
253
+ _get_counts(args.dataset, args.device)
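`count.py` builds its sum pooling from `nn.AvgPool2d` with `divisor_override=1`, so each output value is the sum of the density map inside one window; it slides the window with stride 1 to enumerate every window position. A minimal standalone check of that trick (not part of the repo; non-overlapping blocks are used here for brevity):

```python
import torch
from torch import nn

# Two point annotations on an 8x8 density map.
density = torch.zeros(1, 1, 8, 8)
density[0, 0, 2, 3] = 1.0
density[0, 0, 5, 6] = 1.0

# divisor_override=1 turns average pooling into sum pooling.
sum_pool = nn.AvgPool2d(kernel_size=4, stride=4, divisor_override=1)
block_counts = sum_pool(density)

print(block_counts.squeeze())                              # per-block head counts
assert block_counts.sum().item() == density.sum().item()   # total mass is preserved
```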
count.sh ADDED
@@ -0,0 +1,5 @@
1
+ #!/bin/sh
2
+ python count.py --dataset shanghaitech_a --device cuda:0
3
+ python count.py --dataset shanghaitech_b --device cuda:0
4
+ python count.py --dataset nwpu --device cuda:0
5
+ python count.py --dataset ucf_qnrf --device cuda:0
counts/jhu.json ADDED
@@ -0,0 +1,425 @@
1
+ {
2
+ "1": {
3
+ "0": 5442129077,
4
+ "1": 844619
5
+ },
6
+ "4": {
7
+ "0": 5411259934,
8
+ "1": 13337323,
9
+ "2": 75154,
10
+ "3": 1725,
11
+ "4": 40
12
+ },
13
+ "7": {
14
+ "0": 5366145063,
15
+ "1": 39388535,
16
+ "2": 807008,
17
+ "3": 68635,
18
+ "4": 5975,
19
+ "5": 318,
20
+ "6": 17,
21
+ "7": 1
22
+ },
23
+ "8": {
24
+ "0": 5348298656,
25
+ "1": 50463806,
26
+ "2": 1400221,
27
+ "3": 154835,
28
+ "4": 19051,
29
+ "5": 1731,
30
+ "6": 121,
31
+ "7": 11
32
+ },
33
+ "14": {
34
+ "0": 5220148724,
35
+ "1": 129080801,
36
+ "2": 11196548,
37
+ "3": 2346703,
38
+ "4": 762426,
39
+ "5": 281109,
40
+ "6": 104707,
41
+ "7": 35659,
42
+ "8": 10533,
43
+ "9": 2989,
44
+ "10": 724,
45
+ "11": 196,
46
+ "12": 16,
47
+ "13": 1
48
+ },
49
+ "16": {
50
+ "0": 5172190839,
51
+ "1": 156244565,
52
+ "2": 17047061,
53
+ "3": 3987628,
54
+ "4": 1373739,
55
+ "5": 580316,
56
+ "6": 265393,
57
+ "7": 117895,
58
+ "8": 48278,
59
+ "9": 18825,
60
+ "10": 6835,
61
+ "11": 2535,
62
+ "12": 909,
63
+ "13": 209,
64
+ "14": 27,
65
+ "15": 2
66
+ },
67
+ "28": {
68
+ "0": 4868806093,
69
+ "1": 296210451,
70
+ "2": 64607415,
71
+ "3": 23796771,
72
+ "4": 11220229,
73
+ "5": 5869184,
74
+ "6": 3249319,
75
+ "7": 1854162,
76
+ "8": 1153843,
77
+ "9": 778472,
78
+ "10": 561910,
79
+ "11": 425259,
80
+ "12": 332715,
81
+ "13": 255032,
82
+ "14": 191332,
83
+ "15": 137704,
84
+ "16": 95475,
85
+ "17": 64842,
86
+ "18": 43528,
87
+ "19": 29738,
88
+ "20": 20028,
89
+ "21": 13687,
90
+ "22": 9609,
91
+ "23": 7228,
92
+ "24": 4847,
93
+ "25": 3457,
94
+ "26": 2563,
95
+ "27": 1831,
96
+ "28": 1349,
97
+ "29": 917,
98
+ "30": 589,
99
+ "31": 360,
100
+ "32": 213,
101
+ "33": 94,
102
+ "34": 22,
103
+ "35": 4
104
+ },
105
+ "32": {
106
+ "0": 4768229484,
107
+ "1": 332242168,
108
+ "2": 81810540,
109
+ "3": 32189657,
110
+ "4": 16022983,
111
+ "5": 8984314,
112
+ "6": 5419164,
113
+ "7": 3339453,
114
+ "8": 2097270,
115
+ "9": 1359271,
116
+ "10": 927341,
117
+ "11": 673849,
118
+ "12": 519302,
119
+ "13": 413081,
120
+ "14": 339682,
121
+ "15": 282493,
122
+ "16": 235154,
123
+ "17": 189365,
124
+ "18": 147778,
125
+ "19": 111779,
126
+ "20": 83938,
127
+ "21": 61440,
128
+ "22": 44843,
129
+ "23": 32312,
130
+ "24": 23514,
131
+ "25": 17003,
132
+ "26": 12718,
133
+ "27": 9671,
134
+ "28": 7115,
135
+ "29": 5853,
136
+ "30": 4515,
137
+ "31": 3342,
138
+ "32": 2525,
139
+ "33": 1880,
140
+ "34": 1522,
141
+ "35": 1199,
142
+ "36": 1034,
143
+ "37": 733,
144
+ "38": 561,
145
+ "39": 400,
146
+ "40": 287,
147
+ "41": 134,
148
+ "42": 62,
149
+ "43": 19,
150
+ "44": 4
151
+ },
152
+ "56": {
153
+ "0": 4222181888,
154
+ "1": 453337627,
155
+ "2": 170668322,
156
+ "3": 85503361,
157
+ "4": 50077828,
158
+ "5": 32125898,
159
+ "6": 22063372,
160
+ "7": 15687182,
161
+ "8": 11585957,
162
+ "9": 8807535,
163
+ "10": 6902417,
164
+ "11": 5494688,
165
+ "12": 4464497,
166
+ "13": 3672794,
167
+ "14": 3059884,
168
+ "15": 2569337,
169
+ "16": 2181015,
170
+ "17": 1848256,
171
+ "18": 1568914,
172
+ "19": 1327646,
173
+ "20": 1110617,
174
+ "21": 923381,
175
+ "22": 763225,
176
+ "23": 634769,
177
+ "24": 533036,
178
+ "25": 446198,
179
+ "26": 375536,
180
+ "27": 319752,
181
+ "28": 277970,
182
+ "29": 246034,
183
+ "30": 221081,
184
+ "31": 200820,
185
+ "32": 185527,
186
+ "33": 172457,
187
+ "34": 163190,
188
+ "35": 155461,
189
+ "36": 149548,
190
+ "37": 144236,
191
+ "38": 139882,
192
+ "39": 134703,
193
+ "40": 129346,
194
+ "41": 123503,
195
+ "42": 117688,
196
+ "43": 109973,
197
+ "44": 101970,
198
+ "45": 94300,
199
+ "46": 87095,
200
+ "47": 80710,
201
+ "48": 73843,
202
+ "49": 66773,
203
+ "50": 61099,
204
+ "51": 55590,
205
+ "52": 48984,
206
+ "53": 43741,
207
+ "54": 38838,
208
+ "55": 34038,
209
+ "56": 30826,
210
+ "57": 28088,
211
+ "58": 25668,
212
+ "59": 23430,
213
+ "60": 21750,
214
+ "61": 18902,
215
+ "62": 16508,
216
+ "63": 14272,
217
+ "64": 12549,
218
+ "65": 10596,
219
+ "66": 9228,
220
+ "67": 8081,
221
+ "68": 7185,
222
+ "69": 6284,
223
+ "70": 5698,
224
+ "71": 5124,
225
+ "72": 4488,
226
+ "73": 3761,
227
+ "74": 3171,
228
+ "75": 2908,
229
+ "76": 2554,
230
+ "77": 2211,
231
+ "78": 1956,
232
+ "79": 1784,
233
+ "80": 1529,
234
+ "81": 1317,
235
+ "82": 1189,
236
+ "83": 1136,
237
+ "84": 1086,
238
+ "85": 1012,
239
+ "86": 890,
240
+ "87": 914,
241
+ "88": 895,
242
+ "89": 832,
243
+ "90": 698,
244
+ "91": 607,
245
+ "92": 546,
246
+ "93": 526,
247
+ "94": 411,
248
+ "95": 386,
249
+ "96": 372,
250
+ "97": 415,
251
+ "98": 428,
252
+ "99": 487,
253
+ "100": 506,
254
+ "101": 549,
255
+ "102": 453,
256
+ "103": 475,
257
+ "104": 432,
258
+ "105": 391,
259
+ "106": 349,
260
+ "107": 307,
261
+ "108": 236,
262
+ "109": 183,
263
+ "110": 162,
264
+ "111": 128,
265
+ "112": 97,
266
+ "113": 48,
267
+ "114": 34,
268
+ "115": 14,
269
+ "116": 10,
270
+ "117": 7,
271
+ "118": 3,
272
+ "119": 1,
273
+ "120": 1
274
+ },
275
+ "64": {
276
+ "0": 4064136120,
277
+ "1": 469518405,
278
+ "2": 190549696,
279
+ "3": 101410734,
280
+ "4": 61441010,
281
+ "5": 40341860,
282
+ "6": 28363124,
283
+ "7": 20699526,
284
+ "8": 15647286,
285
+ "9": 12025617,
286
+ "10": 9421729,
287
+ "11": 7602900,
288
+ "12": 6244037,
289
+ "13": 5183786,
290
+ "14": 4355369,
291
+ "15": 3680829,
292
+ "16": 3145664,
293
+ "17": 2707446,
294
+ "18": 2348723,
295
+ "19": 2053730,
296
+ "20": 1802355,
297
+ "21": 1584446,
298
+ "22": 1402996,
299
+ "23": 1243258,
300
+ "24": 1087095,
301
+ "25": 947714,
302
+ "26": 818905,
303
+ "27": 707951,
304
+ "28": 615285,
305
+ "29": 531101,
306
+ "30": 459448,
307
+ "31": 397639,
308
+ "32": 343028,
309
+ "33": 295704,
310
+ "34": 259036,
311
+ "35": 229935,
312
+ "36": 207856,
313
+ "37": 189177,
314
+ "38": 173617,
315
+ "39": 158969,
316
+ "40": 147768,
317
+ "41": 139725,
318
+ "42": 132730,
319
+ "43": 127226,
320
+ "44": 122630,
321
+ "45": 118232,
322
+ "46": 115769,
323
+ "47": 114576,
324
+ "48": 111942,
325
+ "49": 107720,
326
+ "50": 105347,
327
+ "51": 101643,
328
+ "52": 98838,
329
+ "53": 96240,
330
+ "54": 91117,
331
+ "55": 87247,
332
+ "56": 82358,
333
+ "57": 77480,
334
+ "58": 72990,
335
+ "59": 68837,
336
+ "60": 65050,
337
+ "61": 61515,
338
+ "62": 57758,
339
+ "63": 53659,
340
+ "64": 50371,
341
+ "65": 45903,
342
+ "66": 42190,
343
+ "67": 39241,
344
+ "68": 35555,
345
+ "69": 32655,
346
+ "70": 29239,
347
+ "71": 26825,
348
+ "72": 24122,
349
+ "73": 22333,
350
+ "74": 21327,
351
+ "75": 19766,
352
+ "76": 18539,
353
+ "77": 16797,
354
+ "78": 15217,
355
+ "79": 13961,
356
+ "80": 12377,
357
+ "81": 11299,
358
+ "82": 9960,
359
+ "83": 8982,
360
+ "84": 7921,
361
+ "85": 7244,
362
+ "86": 6267,
363
+ "87": 5707,
364
+ "88": 5185,
365
+ "89": 4541,
366
+ "90": 4292,
367
+ "91": 3572,
368
+ "92": 3041,
369
+ "93": 2757,
370
+ "94": 2416,
371
+ "95": 2182,
372
+ "96": 1973,
373
+ "97": 1646,
374
+ "98": 1472,
375
+ "99": 1468,
376
+ "100": 1411,
377
+ "101": 1402,
378
+ "102": 1289,
379
+ "103": 1163,
380
+ "104": 983,
381
+ "105": 838,
382
+ "106": 777,
383
+ "107": 744,
384
+ "108": 689,
385
+ "109": 651,
386
+ "110": 651,
387
+ "111": 586,
388
+ "112": 523,
389
+ "113": 508,
390
+ "114": 464,
391
+ "115": 446,
392
+ "116": 428,
393
+ "117": 423,
394
+ "118": 390,
395
+ "119": 417,
396
+ "120": 363,
397
+ "121": 317,
398
+ "122": 316,
399
+ "123": 339,
400
+ "124": 340,
401
+ "125": 372,
402
+ "126": 372,
403
+ "127": 339,
404
+ "128": 403,
405
+ "129": 405,
406
+ "130": 428,
407
+ "131": 406,
408
+ "132": 409,
409
+ "133": 419,
410
+ "134": 396,
411
+ "135": 311,
412
+ "136": 288,
413
+ "137": 243,
414
+ "138": 195,
415
+ "139": 150,
416
+ "140": 158,
417
+ "141": 114,
418
+ "142": 105,
419
+ "143": 67,
420
+ "144": 33,
421
+ "145": 11,
422
+ "146": 7,
423
+ "147": 1
424
+ }
425
+ }
counts/jhu_max.json ADDED
@@ -0,0 +1,74 @@
1
+ {
2
+ "4": {
3
+ "max": 4,
4
+ "name": [
5
+ "0050.jpg"
6
+ ],
7
+ "x": 672,
8
+ "y": 1315
9
+ },
10
+ "7": {
11
+ "max": 7,
12
+ "name": [
13
+ "0154.jpg"
14
+ ],
15
+ "x": 338,
16
+ "y": 1337
17
+ },
18
+ "8": {
19
+ "max": 7,
20
+ "name": [
21
+ "0144.jpg"
22
+ ],
23
+ "x": 639,
24
+ "y": 943
25
+ },
26
+ "14": {
27
+ "max": 13,
28
+ "name": [
29
+ "1162.jpg"
30
+ ],
31
+ "x": 604,
32
+ "y": 702
33
+ },
34
+ "16": {
35
+ "max": 15,
36
+ "name": [
37
+ "0193.jpg"
38
+ ],
39
+ "x": 593,
40
+ "y": 286
41
+ },
42
+ "28": {
43
+ "max": 35,
44
+ "name": [
45
+ "1162.jpg"
46
+ ],
47
+ "x": 578,
48
+ "y": 706
49
+ },
50
+ "32": {
51
+ "max": 44,
52
+ "name": [
53
+ "0193.jpg"
54
+ ],
55
+ "x": 596,
56
+ "y": 263
57
+ },
58
+ "56": {
59
+ "max": 120,
60
+ "name": [
61
+ "1162.jpg"
62
+ ],
63
+ "x": 562,
64
+ "y": 671
65
+ },
66
+ "64": {
67
+ "max": 147,
68
+ "name": [
69
+ "1162.jpg"
70
+ ],
71
+ "x": 562,
72
+ "y": 663
73
+ }
74
+ }
counts/nwpu.json ADDED
@@ -0,0 +1,761 @@
1
+ {
2
+ "1": {
3
+ "0": 14667500579,
4
+ "1": 1291229
5
+ },
6
+ "4": {
7
+ "0": 14607991573,
8
+ "1": 20424516,
9
+ "2": 101146,
10
+ "3": 3581,
11
+ "4": 173,
12
+ "5": 15,
13
+ "6": 1
14
+ },
15
+ "7": {
16
+ "0": 14527228244,
17
+ "1": 59342625,
18
+ "2": 1508064,
19
+ "3": 181138,
20
+ "4": 35423,
21
+ "5": 8206,
22
+ "6": 1925,
23
+ "7": 424,
24
+ "8": 92,
25
+ "9": 19,
26
+ "10": 4
27
+ },
28
+ "8": {
29
+ "0": 14496291716,
30
+ "1": 75535689,
31
+ "2": 2593492,
32
+ "3": 373522,
33
+ "4": 85180,
34
+ "5": 23129,
35
+ "6": 7605,
36
+ "7": 2404,
37
+ "8": 694,
38
+ "9": 170,
39
+ "10": 45,
40
+ "11": 7
41
+ },
42
+ "14": {
43
+ "0": 14280365725,
44
+ "1": 189868793,
45
+ "2": 17508005,
46
+ "3": 4140432,
47
+ "4": 1496968,
48
+ "5": 646243,
49
+ "6": 308292,
50
+ "7": 154512,
51
+ "8": 80925,
52
+ "9": 45696,
53
+ "10": 26811,
54
+ "11": 16841,
55
+ "12": 10489,
56
+ "13": 6798,
57
+ "14": 4437,
58
+ "15": 3038,
59
+ "16": 2097,
60
+ "17": 1426,
61
+ "18": 850,
62
+ "19": 434,
63
+ "20": 198,
64
+ "21": 105,
65
+ "22": 36,
66
+ "23": 14
67
+ },
68
+ "16": {
69
+ "0": 14200293041,
70
+ "1": 230337258,
71
+ "2": 25716807,
72
+ "3": 6591144,
73
+ "4": 2496616,
74
+ "5": 1151263,
75
+ "6": 597759,
76
+ "7": 328222,
77
+ "8": 186538,
78
+ "9": 107834,
79
+ "10": 64201,
80
+ "11": 40386,
81
+ "12": 26336,
82
+ "13": 17791,
83
+ "14": 12514,
84
+ "15": 8477,
85
+ "16": 6021,
86
+ "17": 4371,
87
+ "18": 3322,
88
+ "19": 2369,
89
+ "20": 1800,
90
+ "21": 1260,
91
+ "22": 892,
92
+ "23": 581,
93
+ "24": 317,
94
+ "25": 166,
95
+ "26": 88,
96
+ "27": 32,
97
+ "28": 5,
98
+ "29": 2
99
+ },
100
+ "28": {
101
+ "0": 13684329722,
102
+ "1": 456956241,
103
+ "2": 91566961,
104
+ "3": 34512257,
105
+ "4": 16402331,
106
+ "5": 8518065,
107
+ "6": 4898436,
108
+ "7": 3032957,
109
+ "8": 2020921,
110
+ "9": 1422203,
111
+ "10": 1041284,
112
+ "11": 785822,
113
+ "12": 600472,
114
+ "13": 463060,
115
+ "14": 356398,
116
+ "15": 278057,
117
+ "16": 220282,
118
+ "17": 175747,
119
+ "18": 141679,
120
+ "19": 115020,
121
+ "20": 92598,
122
+ "21": 75190,
123
+ "22": 61616,
124
+ "23": 50395,
125
+ "24": 40763,
126
+ "25": 33009,
127
+ "26": 26142,
128
+ "27": 21024,
129
+ "28": 16921,
130
+ "29": 14076,
131
+ "30": 11489,
132
+ "31": 10146,
133
+ "32": 8692,
134
+ "33": 7935,
135
+ "34": 7289,
136
+ "35": 6638,
137
+ "36": 5728,
138
+ "37": 5150,
139
+ "38": 4441,
140
+ "39": 3978,
141
+ "40": 3510,
142
+ "41": 3071,
143
+ "42": 2914,
144
+ "43": 2538,
145
+ "44": 2234,
146
+ "45": 1886,
147
+ "46": 1685,
148
+ "47": 1411,
149
+ "48": 1205,
150
+ "49": 1020,
151
+ "50": 817,
152
+ "51": 754,
153
+ "52": 696,
154
+ "53": 585,
155
+ "54": 540,
156
+ "55": 512,
157
+ "56": 444,
158
+ "57": 426,
159
+ "58": 364,
160
+ "59": 257,
161
+ "60": 212,
162
+ "61": 197,
163
+ "62": 157,
164
+ "63": 133,
165
+ "64": 108,
166
+ "65": 83,
167
+ "66": 95,
168
+ "67": 69,
169
+ "68": 64,
170
+ "69": 35,
171
+ "70": 21,
172
+ "71": 12,
173
+ "72": 9,
174
+ "73": 8,
175
+ "74": 3,
176
+ "75": 3
177
+ },
178
+ "32": {
179
+ "0": 13507181488,
180
+ "1": 523684788,
181
+ "2": 115677502,
182
+ "3": 46067053,
183
+ "4": 23384978,
184
+ "5": 13033305,
185
+ "6": 7798986,
186
+ "7": 4827879,
187
+ "8": 3222733,
188
+ "9": 2262098,
189
+ "10": 1651589,
190
+ "11": 1247118,
191
+ "12": 967386,
192
+ "13": 771426,
193
+ "14": 621546,
194
+ "15": 504368,
195
+ "16": 409418,
196
+ "17": 332421,
197
+ "18": 271277,
198
+ "19": 222138,
199
+ "20": 183772,
200
+ "21": 152433,
201
+ "22": 128423,
202
+ "23": 108428,
203
+ "24": 93487,
204
+ "25": 79093,
205
+ "26": 67728,
206
+ "27": 56196,
207
+ "28": 47634,
208
+ "29": 40579,
209
+ "30": 34355,
210
+ "31": 28984,
211
+ "32": 24565,
212
+ "33": 20972,
213
+ "34": 17931,
214
+ "35": 14995,
215
+ "36": 12377,
216
+ "37": 10307,
217
+ "38": 8797,
218
+ "39": 7610,
219
+ "40": 6846,
220
+ "41": 6271,
221
+ "42": 5855,
222
+ "43": 5378,
223
+ "44": 5294,
224
+ "45": 4945,
225
+ "46": 4528,
226
+ "47": 4172,
227
+ "48": 3883,
228
+ "49": 3522,
229
+ "50": 3246,
230
+ "51": 2948,
231
+ "52": 2646,
232
+ "53": 2401,
233
+ "54": 2102,
234
+ "55": 1889,
235
+ "56": 1689,
236
+ "57": 1444,
237
+ "58": 1354,
238
+ "59": 1166,
239
+ "60": 966,
240
+ "61": 796,
241
+ "62": 695,
242
+ "63": 629,
243
+ "64": 585,
244
+ "65": 531,
245
+ "66": 518,
246
+ "67": 482,
247
+ "68": 442,
248
+ "69": 385,
249
+ "70": 358,
250
+ "71": 335,
251
+ "72": 307,
252
+ "73": 267,
253
+ "74": 271,
254
+ "75": 220,
255
+ "76": 210,
256
+ "77": 180,
257
+ "78": 147,
258
+ "79": 124,
259
+ "80": 116,
260
+ "81": 112,
261
+ "82": 93,
262
+ "83": 69,
263
+ "84": 56,
264
+ "85": 52,
265
+ "86": 23,
266
+ "87": 17,
267
+ "88": 14,
268
+ "89": 11,
269
+ "90": 14,
270
+ "91": 6,
271
+ "92": 6,
272
+ "93": 3,
273
+ "94": 6,
274
+ "95": 1
275
+ },
276
+ "56": {
277
+ "0": 12465097246,
278
+ "1": 835084317,
279
+ "2": 254687121,
280
+ "3": 121720894,
281
+ "4": 71341732,
282
+ "5": 45465642,
283
+ "6": 31016406,
284
+ "7": 22117585,
285
+ "8": 16576017,
286
+ "9": 12843282,
287
+ "10": 10188871,
288
+ "11": 8166753,
289
+ "12": 6639505,
290
+ "13": 5403165,
291
+ "14": 4423601,
292
+ "15": 3641816,
293
+ "16": 2982294,
294
+ "17": 2495500,
295
+ "18": 2107822,
296
+ "19": 1777118,
297
+ "20": 1527177,
298
+ "21": 1320511,
299
+ "22": 1154409,
300
+ "23": 1016008,
301
+ "24": 902921,
302
+ "25": 805297,
303
+ "26": 717731,
304
+ "27": 639994,
305
+ "28": 578216,
306
+ "29": 522654,
307
+ "30": 471731,
308
+ "31": 430710,
309
+ "32": 391310,
310
+ "33": 360727,
311
+ "34": 333244,
312
+ "35": 306947,
313
+ "36": 285386,
314
+ "37": 266777,
315
+ "38": 248721,
316
+ "39": 231377,
317
+ "40": 213535,
318
+ "41": 197555,
319
+ "42": 182232,
320
+ "43": 168988,
321
+ "44": 156079,
322
+ "45": 144746,
323
+ "46": 135302,
324
+ "47": 124226,
325
+ "48": 114096,
326
+ "49": 104673,
327
+ "50": 95005,
328
+ "51": 87224,
329
+ "52": 81168,
330
+ "53": 76076,
331
+ "54": 71286,
332
+ "55": 67529,
333
+ "56": 64050,
334
+ "57": 62041,
335
+ "58": 58650,
336
+ "59": 55931,
337
+ "60": 51249,
338
+ "61": 47542,
339
+ "62": 44191,
340
+ "63": 41598,
341
+ "64": 38416,
342
+ "65": 36328,
343
+ "66": 33839,
344
+ "67": 32088,
345
+ "68": 30559,
346
+ "69": 27881,
347
+ "70": 26103,
348
+ "71": 24152,
349
+ "72": 22520,
350
+ "73": 20886,
351
+ "74": 19169,
352
+ "75": 17738,
353
+ "76": 16636,
354
+ "77": 15532,
355
+ "78": 14619,
356
+ "79": 14389,
357
+ "80": 13560,
358
+ "81": 13208,
359
+ "82": 12245,
360
+ "83": 11275,
361
+ "84": 10523,
362
+ "85": 10108,
363
+ "86": 9176,
364
+ "87": 8790,
365
+ "88": 8448,
366
+ "89": 8110,
367
+ "90": 7575,
368
+ "91": 7354,
369
+ "92": 6483,
370
+ "93": 6061,
371
+ "94": 5352,
372
+ "95": 5181,
373
+ "96": 4845,
374
+ "97": 4594,
375
+ "98": 4342,
376
+ "99": 4193,
377
+ "100": 3899,
378
+ "101": 3674,
379
+ "102": 3565,
380
+ "103": 3285,
381
+ "104": 3059,
382
+ "105": 2778,
383
+ "106": 2658,
384
+ "107": 2485,
385
+ "108": 2345,
386
+ "109": 2303,
387
+ "110": 2210,
388
+ "111": 2095,
389
+ "112": 1975,
390
+ "113": 1975,
391
+ "114": 2058,
392
+ "115": 1969,
393
+ "116": 1914,
394
+ "117": 1934,
395
+ "118": 1928,
396
+ "119": 1914,
397
+ "120": 1954,
398
+ "121": 1943,
399
+ "122": 1997,
400
+ "123": 2085,
401
+ "124": 1841,
402
+ "125": 1728,
403
+ "126": 1603,
404
+ "127": 1530,
405
+ "128": 1426,
406
+ "129": 1355,
407
+ "130": 1309,
408
+ "131": 1340,
409
+ "132": 1256,
410
+ "133": 1260,
411
+ "134": 1219,
412
+ "135": 1086,
413
+ "136": 1079,
414
+ "137": 1004,
415
+ "138": 987,
416
+ "139": 996,
417
+ "140": 886,
418
+ "141": 841,
419
+ "142": 786,
420
+ "143": 799,
421
+ "144": 882,
422
+ "145": 782,
423
+ "146": 718,
424
+ "147": 672,
425
+ "148": 629,
426
+ "149": 578,
427
+ "150": 592,
428
+ "151": 602,
429
+ "152": 564,
430
+ "153": 573,
431
+ "154": 551,
432
+ "155": 484,
433
+ "156": 474,
434
+ "157": 435,
435
+ "158": 410,
436
+ "159": 376,
437
+ "160": 348,
438
+ "161": 366,
439
+ "162": 299,
440
+ "163": 304,
441
+ "164": 280,
442
+ "165": 301,
443
+ "166": 298,
444
+ "167": 266,
445
+ "168": 259,
446
+ "169": 288,
447
+ "170": 259,
448
+ "171": 232,
449
+ "172": 249,
450
+ "173": 229,
451
+ "174": 197,
452
+ "175": 254,
453
+ "176": 204,
454
+ "177": 211,
455
+ "178": 208,
456
+ "179": 199,
457
+ "180": 183,
458
+ "181": 169,
459
+ "182": 169,
460
+ "183": 169,
461
+ "184": 120,
462
+ "185": 119,
463
+ "186": 151,
464
+ "187": 131,
465
+ "188": 126,
466
+ "189": 122,
467
+ "190": 107,
468
+ "191": 105,
469
+ "192": 103,
470
+ "193": 87,
471
+ "194": 71,
472
+ "195": 62,
473
+ "196": 59,
474
+ "197": 51,
475
+ "198": 40,
476
+ "199": 49,
477
+ "200": 44,
478
+ "201": 45,
479
+ "202": 43,
480
+ "203": 42,
481
+ "204": 36,
482
+ "205": 45,
483
+ "206": 36,
484
+ "207": 37,
485
+ "208": 38,
486
+ "209": 32,
487
+ "210": 27,
488
+ "211": 25,
489
+ "212": 21,
490
+ "213": 19,
491
+ "214": 30,
492
+ "215": 16,
493
+ "216": 20,
494
+ "217": 15,
495
+ "218": 14,
496
+ "219": 6,
497
+ "220": 8,
498
+ "221": 5,
499
+ "222": 3,
500
+ "223": 2
501
+ },
502
+ "64": {
503
+ "0": 12134170560,
504
+ "1": 910355445,
505
+ "2": 297133671,
506
+ "3": 145184087,
507
+ "4": 87626341,
508
+ "5": 57746135,
509
+ "6": 40495922,
510
+ "7": 29156512,
511
+ "8": 21919906,
512
+ "9": 16973043,
513
+ "10": 13535308,
514
+ "11": 11038546,
515
+ "12": 9149626,
516
+ "13": 7600687,
517
+ "14": 6410824,
518
+ "15": 5491781,
519
+ "16": 4677502,
520
+ "17": 3997198,
521
+ "18": 3443407,
522
+ "19": 2925959,
523
+ "20": 2507301,
524
+ "21": 2160448,
525
+ "22": 1878716,
526
+ "23": 1648075,
527
+ "24": 1450872,
528
+ "25": 1275043,
529
+ "26": 1133498,
530
+ "27": 1015835,
531
+ "28": 914243,
532
+ "29": 833304,
533
+ "30": 760872,
534
+ "31": 691863,
535
+ "32": 630584,
536
+ "33": 577966,
537
+ "34": 528643,
538
+ "35": 485362,
539
+ "36": 444354,
540
+ "37": 407675,
541
+ "38": 377100,
542
+ "39": 351641,
543
+ "40": 326893,
544
+ "41": 305689,
545
+ "42": 285689,
546
+ "43": 266757,
547
+ "44": 249514,
548
+ "45": 235532,
549
+ "46": 223892,
550
+ "47": 211932,
551
+ "48": 200323,
552
+ "49": 189578,
553
+ "50": 178068,
554
+ "51": 167402,
555
+ "52": 158785,
556
+ "53": 149971,
557
+ "54": 140597,
558
+ "55": 131198,
559
+ "56": 124442,
560
+ "57": 118109,
561
+ "58": 111071,
562
+ "59": 104882,
563
+ "60": 97607,
564
+ "61": 91490,
565
+ "62": 85286,
566
+ "63": 79531,
567
+ "64": 74921,
568
+ "65": 69722,
569
+ "66": 67061,
570
+ "67": 62855,
571
+ "68": 59431,
572
+ "69": 56425,
573
+ "70": 53389,
574
+ "71": 52205,
575
+ "72": 49130,
576
+ "73": 47540,
577
+ "74": 46130,
578
+ "75": 44031,
579
+ "76": 41069,
580
+ "77": 38590,
581
+ "78": 36372,
582
+ "79": 34739,
583
+ "80": 32483,
584
+ "81": 30821,
585
+ "82": 29084,
586
+ "83": 27658,
587
+ "84": 26356,
588
+ "85": 25296,
589
+ "86": 24161,
590
+ "87": 22766,
591
+ "88": 21596,
592
+ "89": 20576,
593
+ "90": 19734,
594
+ "91": 18715,
595
+ "92": 17676,
596
+ "93": 16389,
597
+ "94": 15235,
598
+ "95": 14115,
599
+ "96": 13051,
600
+ "97": 12336,
601
+ "98": 11769,
602
+ "99": 10974,
603
+ "100": 10731,
604
+ "101": 9897,
605
+ "102": 9661,
606
+ "103": 9456,
607
+ "104": 9255,
608
+ "105": 9143,
609
+ "106": 8863,
610
+ "107": 8535,
611
+ "108": 8059,
612
+ "109": 7377,
613
+ "110": 7024,
614
+ "111": 6470,
615
+ "112": 6426,
616
+ "113": 6009,
617
+ "114": 5748,
618
+ "115": 5535,
619
+ "116": 5244,
620
+ "117": 4876,
621
+ "118": 4586,
622
+ "119": 4234,
623
+ "120": 4118,
624
+ "121": 3789,
625
+ "122": 3695,
626
+ "123": 3622,
627
+ "124": 3493,
628
+ "125": 3318,
629
+ "126": 3359,
630
+ "127": 3420,
631
+ "128": 3353,
632
+ "129": 3224,
633
+ "130": 3222,
634
+ "131": 3038,
635
+ "132": 2831,
636
+ "133": 2743,
637
+ "134": 2751,
638
+ "135": 2703,
639
+ "136": 2517,
640
+ "137": 2404,
641
+ "138": 2360,
642
+ "139": 2069,
643
+ "140": 2037,
644
+ "141": 1829,
645
+ "142": 1693,
646
+ "143": 1599,
647
+ "144": 1588,
648
+ "145": 1482,
649
+ "146": 1408,
650
+ "147": 1386,
651
+ "148": 1339,
652
+ "149": 1401,
653
+ "150": 1313,
654
+ "151": 1276,
655
+ "152": 1276,
656
+ "153": 1179,
657
+ "154": 1242,
658
+ "155": 1267,
659
+ "156": 1184,
660
+ "157": 1245,
661
+ "158": 1187,
662
+ "159": 1113,
663
+ "160": 1095,
664
+ "161": 1059,
665
+ "162": 938,
666
+ "163": 958,
667
+ "164": 906,
668
+ "165": 920,
669
+ "166": 941,
670
+ "167": 905,
671
+ "168": 885,
672
+ "169": 873,
673
+ "170": 794,
674
+ "171": 741,
675
+ "172": 773,
676
+ "173": 713,
677
+ "174": 694,
678
+ "175": 689,
679
+ "176": 741,
680
+ "177": 770,
681
+ "178": 735,
682
+ "179": 747,
683
+ "180": 704,
684
+ "181": 670,
685
+ "182": 652,
686
+ "183": 635,
687
+ "184": 633,
688
+ "185": 682,
689
+ "186": 598,
690
+ "187": 590,
691
+ "188": 541,
692
+ "189": 526,
693
+ "190": 495,
694
+ "191": 508,
695
+ "192": 492,
696
+ "193": 501,
697
+ "194": 443,
698
+ "195": 444,
699
+ "196": 399,
700
+ "197": 363,
701
+ "198": 357,
702
+ "199": 338,
703
+ "200": 292,
704
+ "201": 273,
705
+ "202": 288,
706
+ "203": 292,
707
+ "204": 280,
708
+ "205": 260,
709
+ "206": 278,
710
+ "207": 243,
711
+ "208": 212,
712
+ "209": 241,
713
+ "210": 217,
714
+ "211": 189,
715
+ "212": 195,
716
+ "213": 181,
717
+ "214": 179,
718
+ "215": 238,
719
+ "216": 196,
720
+ "217": 195,
721
+ "218": 181,
722
+ "219": 191,
723
+ "220": 158,
724
+ "221": 154,
725
+ "222": 178,
726
+ "223": 150,
727
+ "224": 149,
728
+ "225": 155,
729
+ "226": 184,
730
+ "227": 125,
731
+ "228": 154,
732
+ "229": 135,
733
+ "230": 153,
734
+ "231": 151,
735
+ "232": 153,
736
+ "233": 124,
737
+ "234": 110,
738
+ "235": 87,
739
+ "236": 95,
740
+ "237": 76,
741
+ "238": 75,
742
+ "239": 69,
743
+ "240": 67,
744
+ "241": 60,
745
+ "242": 36,
746
+ "243": 42,
747
+ "244": 55,
748
+ "245": 41,
749
+ "246": 58,
750
+ "247": 46,
751
+ "248": 37,
752
+ "249": 33,
753
+ "250": 29,
754
+ "251": 23,
755
+ "252": 13,
756
+ "253": 3,
757
+ "254": 11,
758
+ "255": 9,
759
+ "256": 2
760
+ }
761
+ }
counts/nwpu_max.json ADDED
@@ -0,0 +1,74 @@
1
+ {
2
+ "4": {
3
+ "max": 6,
4
+ "name": [
5
+ "0701.jpg"
6
+ ],
7
+ "x": 976,
8
+ "y": 1527
9
+ },
10
+ "7": {
11
+ "max": 10,
12
+ "name": [
13
+ "0181.jpg"
14
+ ],
15
+ "x": 639,
16
+ "y": 1531
17
+ },
18
+ "8": {
19
+ "max": 11,
20
+ "name": [
21
+ "1838.jpg"
22
+ ],
23
+ "x": 815,
24
+ "y": 1001
25
+ },
26
+ "14": {
27
+ "max": 23,
28
+ "name": [
29
+ "1838.jpg"
30
+ ],
31
+ "x": 995,
32
+ "y": 1544
33
+ },
34
+ "16": {
35
+ "max": 29,
36
+ "name": [
37
+ "1838.jpg"
38
+ ],
39
+ "x": 991,
40
+ "y": 1544
41
+ },
42
+ "28": {
43
+ "max": 75,
44
+ "name": [
45
+ "1838.jpg"
46
+ ],
47
+ "x": 1003,
48
+ "y": 1706
49
+ },
50
+ "32": {
51
+ "max": 95,
52
+ "name": [
53
+ "1838.jpg"
54
+ ],
55
+ "x": 1003,
56
+ "y": 1704
57
+ },
58
+ "56": {
59
+ "max": 223,
60
+ "name": [
61
+ "1838.jpg"
62
+ ],
63
+ "x": 993,
64
+ "y": 1702
65
+ },
66
+ "64": {
67
+ "max": 256,
68
+ "name": [
69
+ "1838.jpg"
70
+ ],
71
+ "x": 990,
72
+ "y": 1697
73
+ }
74
+ }
counts/qnrf.json ADDED
@@ -0,0 +1,569 @@
1
+ {
2
+ "1": {
3
+ "0": 2703096261,
4
+ "1": 1007163
5
+ },
6
+ "4": {
7
+ "0": 2677404968,
8
+ "1": 15969215,
9
+ "2": 59807,
10
+ "3": 1384,
11
+ "4": 97,
12
+ "5": 10
13
+ },
14
+ "7": {
15
+ "0": 2635421382,
16
+ "1": 45742892,
17
+ "2": 1492537,
18
+ "3": 114192,
19
+ "4": 14549,
20
+ "5": 2676,
21
+ "6": 675,
22
+ "7": 199,
23
+ "8": 47,
24
+ "9": 7
25
+ },
26
+ "8": {
27
+ "0": 2618651922,
28
+ "1": 57473873,
29
+ "2": 2778844,
30
+ "3": 286508,
31
+ "4": 41982,
32
+ "5": 8626,
33
+ "6": 2306,
34
+ "7": 782,
35
+ "8": 241,
36
+ "9": 77,
37
+ "10": 21,
38
+ "11": 3
39
+ },
40
+ "14": {
41
+ "0": 2502680139,
42
+ "1": 128220308,
43
+ "2": 19112473,
44
+ "3": 5245278,
45
+ "4": 1729894,
46
+ "5": 624274,
47
+ "6": 238250,
48
+ "7": 97230,
49
+ "8": 41347,
50
+ "9": 19325,
51
+ "10": 9833,
52
+ "11": 5696,
53
+ "12": 3361,
54
+ "13": 1972,
55
+ "14": 1035,
56
+ "15": 547,
57
+ "16": 340,
58
+ "17": 212,
59
+ "18": 112,
60
+ "19": 71,
61
+ "20": 42,
62
+ "21": 30,
63
+ "22": 14,
64
+ "23": 12,
65
+ "24": 3,
66
+ "25": 3
67
+ },
68
+ "16": {
69
+ "0": 2461366525,
70
+ "1": 149295686,
71
+ "2": 26297062,
72
+ "3": 8379921,
73
+ "4": 3199756,
74
+ "5": 1324146,
75
+ "6": 583234,
76
+ "7": 267593,
77
+ "8": 128122,
78
+ "9": 62843,
79
+ "10": 32265,
80
+ "11": 16540,
81
+ "12": 9297,
82
+ "13": 5835,
83
+ "14": 4037,
84
+ "15": 2616,
85
+ "16": 1660,
86
+ "17": 1066,
87
+ "18": 639,
88
+ "19": 349,
89
+ "20": 203,
90
+ "21": 183,
91
+ "22": 121,
92
+ "23": 80,
93
+ "24": 44,
94
+ "25": 30,
95
+ "26": 11,
96
+ "27": 9,
97
+ "28": 10,
98
+ "29": 6
99
+ },
100
+ "28": {
101
+ "0": 2217981619,
102
+ "1": 242958596,
103
+ "2": 68708089,
104
+ "3": 31034654,
105
+ "4": 17007626,
106
+ "5": 10317353,
107
+ "6": 6556090,
108
+ "7": 4298832,
109
+ "8": 2899688,
110
+ "9": 2014576,
111
+ "10": 1411981,
112
+ "11": 1007963,
113
+ "12": 718139,
114
+ "13": 516552,
115
+ "14": 375188,
116
+ "15": 273595,
117
+ "16": 199599,
118
+ "17": 144002,
119
+ "18": 106107,
120
+ "19": 79309,
121
+ "20": 60015,
122
+ "21": 45839,
123
+ "22": 35538,
124
+ "23": 27006,
125
+ "24": 21141,
126
+ "25": 16063,
127
+ "26": 11666,
128
+ "27": 8786,
129
+ "28": 6812,
130
+ "29": 5341,
131
+ "30": 4314,
132
+ "31": 3339,
133
+ "32": 2718,
134
+ "33": 2165,
135
+ "34": 1611,
136
+ "35": 1444,
137
+ "36": 1299,
138
+ "37": 1057,
139
+ "38": 930,
140
+ "39": 804,
141
+ "40": 590,
142
+ "41": 475,
143
+ "42": 361,
144
+ "43": 297,
145
+ "44": 242,
146
+ "45": 166,
147
+ "46": 125,
148
+ "47": 108,
149
+ "48": 81,
150
+ "49": 90,
151
+ "50": 74,
152
+ "51": 46,
153
+ "52": 38,
154
+ "53": 24,
155
+ "54": 13,
156
+ "55": 7,
157
+ "56": 2
158
+ },
159
+ "32": {
160
+ "0": 2142175706,
161
+ "1": 263796436,
162
+ "2": 81041731,
163
+ "3": 38689542,
164
+ "4": 21861748,
165
+ "5": 13731448,
166
+ "6": 9285756,
167
+ "7": 6423842,
168
+ "8": 4541778,
169
+ "9": 3251892,
170
+ "10": 2387659,
171
+ "11": 1795213,
172
+ "12": 1359736,
173
+ "13": 1036607,
174
+ "14": 790266,
175
+ "15": 606226,
176
+ "16": 469220,
177
+ "17": 361731,
178
+ "18": 281834,
179
+ "19": 218860,
180
+ "20": 168254,
181
+ "21": 130270,
182
+ "22": 100263,
183
+ "23": 78196,
184
+ "24": 61822,
185
+ "25": 49558,
186
+ "26": 39186,
187
+ "27": 32271,
188
+ "28": 26464,
189
+ "29": 21939,
190
+ "30": 17726,
191
+ "31": 14747,
192
+ "32": 11705,
193
+ "33": 9539,
194
+ "34": 7368,
195
+ "35": 5935,
196
+ "36": 4774,
197
+ "37": 3727,
198
+ "38": 3275,
199
+ "39": 2605,
200
+ "40": 2408,
201
+ "41": 1893,
202
+ "42": 1440,
203
+ "43": 1278,
204
+ "44": 1070,
205
+ "45": 915,
206
+ "46": 740,
207
+ "47": 619,
208
+ "48": 507,
209
+ "49": 484,
210
+ "50": 330,
211
+ "51": 374,
212
+ "52": 287,
213
+ "53": 244,
214
+ "54": 223,
215
+ "55": 186,
216
+ "56": 136,
217
+ "57": 120,
218
+ "58": 103,
219
+ "59": 100,
220
+ "60": 88,
221
+ "61": 35,
222
+ "62": 16,
223
+ "63": 23,
224
+ "64": 3,
225
+ "65": 4
226
+ },
227
+ "56": {
228
+ "0": 1753079415,
229
+ "1": 326506415,
230
+ "2": 132989783,
231
+ "3": 73569912,
232
+ "4": 47875219,
233
+ "5": 33905462,
234
+ "6": 25324896,
235
+ "7": 19351405,
236
+ "8": 14991521,
237
+ "9": 11859011,
238
+ "10": 9643110,
239
+ "11": 7988193,
240
+ "12": 6724565,
241
+ "13": 5737528,
242
+ "14": 4905936,
243
+ "15": 4234117,
244
+ "16": 3678710,
245
+ "17": 3233842,
246
+ "18": 2856180,
247
+ "19": 2528729,
248
+ "20": 2238483,
249
+ "21": 1974680,
250
+ "22": 1738379,
251
+ "23": 1522952,
252
+ "24": 1334635,
253
+ "25": 1171843,
254
+ "26": 1038446,
255
+ "27": 924884,
256
+ "28": 828510,
257
+ "29": 749323,
258
+ "30": 680155,
259
+ "31": 619173,
260
+ "32": 558209,
261
+ "33": 507896,
262
+ "34": 463642,
263
+ "35": 419398,
264
+ "36": 380125,
265
+ "37": 347601,
266
+ "38": 318828,
267
+ "39": 293043,
268
+ "40": 272483,
269
+ "41": 250724,
270
+ "42": 228696,
271
+ "43": 206140,
272
+ "44": 184636,
273
+ "45": 165534,
274
+ "46": 149696,
275
+ "47": 135099,
276
+ "48": 121824,
277
+ "49": 110240,
278
+ "50": 98425,
279
+ "51": 88515,
280
+ "52": 79279,
281
+ "53": 70978,
282
+ "54": 64994,
283
+ "55": 59099,
284
+ "56": 53268,
285
+ "57": 48134,
286
+ "58": 43611,
287
+ "59": 38300,
288
+ "60": 34909,
289
+ "61": 31681,
290
+ "62": 28393,
291
+ "63": 24688,
292
+ "64": 21934,
293
+ "65": 19803,
294
+ "66": 17598,
295
+ "67": 15593,
296
+ "68": 14189,
297
+ "69": 13168,
298
+ "70": 12483,
299
+ "71": 11762,
300
+ "72": 11066,
301
+ "73": 10447,
302
+ "74": 9606,
303
+ "75": 8747,
304
+ "76": 7574,
305
+ "77": 6921,
306
+ "78": 6340,
307
+ "79": 6088,
308
+ "80": 5448,
309
+ "81": 5380,
310
+ "82": 5144,
311
+ "83": 5114,
312
+ "84": 4775,
313
+ "85": 4632,
314
+ "86": 4332,
315
+ "87": 4082,
316
+ "88": 3949,
317
+ "89": 3821,
318
+ "90": 3476,
319
+ "91": 3406,
320
+ "92": 2973,
321
+ "93": 2766,
322
+ "94": 2489,
323
+ "95": 2253,
324
+ "96": 2087,
325
+ "97": 1763,
326
+ "98": 1560,
327
+ "99": 1322,
328
+ "100": 1243,
329
+ "101": 1150,
330
+ "102": 994,
331
+ "103": 794,
332
+ "104": 589,
333
+ "105": 538,
334
+ "106": 416,
335
+ "107": 359,
336
+ "108": 335,
337
+ "109": 309,
338
+ "110": 310,
339
+ "111": 280,
340
+ "112": 269,
341
+ "113": 279,
342
+ "114": 233,
343
+ "115": 198,
344
+ "116": 208,
345
+ "117": 211,
346
+ "118": 166,
347
+ "119": 128,
348
+ "120": 127,
349
+ "121": 119,
350
+ "122": 145,
351
+ "123": 159,
352
+ "124": 130,
353
+ "125": 115,
354
+ "126": 124,
355
+ "127": 132,
356
+ "128": 130,
357
+ "129": 114,
358
+ "130": 136,
359
+ "131": 113,
360
+ "132": 119,
361
+ "133": 92,
362
+ "134": 109,
363
+ "135": 94,
364
+ "136": 112,
365
+ "137": 108,
366
+ "138": 107,
367
+ "139": 114,
368
+ "140": 102,
369
+ "141": 63,
370
+ "142": 43,
371
+ "143": 46,
372
+ "144": 34,
373
+ "145": 17,
374
+ "146": 17,
375
+ "147": 4,
376
+ "148": 4
377
+ },
378
+ "64": {
379
+ "0": 1645580394,
380
+ "1": 332121950,
381
+ "2": 143857376,
382
+ "3": 82342244,
383
+ "4": 54254902,
384
+ "5": 38847202,
385
+ "6": 29417465,
386
+ "7": 23205846,
387
+ "8": 18694855,
388
+ "9": 15141642,
389
+ "10": 12371576,
390
+ "11": 10229329,
391
+ "12": 8647553,
392
+ "13": 7325344,
393
+ "14": 6295327,
394
+ "15": 5516930,
395
+ "16": 4865082,
396
+ "17": 4309391,
397
+ "18": 3842162,
398
+ "19": 3406684,
399
+ "20": 3033028,
400
+ "21": 2735522,
401
+ "22": 2473336,
402
+ "23": 2242708,
403
+ "24": 2042061,
404
+ "25": 1862630,
405
+ "26": 1687997,
406
+ "27": 1529651,
407
+ "28": 1377678,
408
+ "29": 1246699,
409
+ "30": 1127615,
410
+ "31": 1021519,
411
+ "32": 919786,
412
+ "33": 835229,
413
+ "34": 758589,
414
+ "35": 694245,
415
+ "36": 637642,
416
+ "37": 589662,
417
+ "38": 547952,
418
+ "39": 507110,
419
+ "40": 467377,
420
+ "41": 431426,
421
+ "42": 399251,
422
+ "43": 369645,
423
+ "44": 345626,
424
+ "45": 320928,
425
+ "46": 300584,
426
+ "47": 279405,
427
+ "48": 261128,
428
+ "49": 245246,
429
+ "50": 230330,
430
+ "51": 216329,
431
+ "52": 202315,
432
+ "53": 188342,
433
+ "54": 175479,
434
+ "55": 164216,
435
+ "56": 151015,
436
+ "57": 138762,
437
+ "58": 128074,
438
+ "59": 118213,
439
+ "60": 109407,
440
+ "61": 100053,
441
+ "62": 91903,
442
+ "63": 83292,
443
+ "64": 75832,
444
+ "65": 68006,
445
+ "66": 61400,
446
+ "67": 55742,
447
+ "68": 51271,
448
+ "69": 47305,
449
+ "70": 43974,
450
+ "71": 39955,
451
+ "72": 36911,
452
+ "73": 34035,
453
+ "74": 30928,
454
+ "75": 28558,
455
+ "76": 26104,
456
+ "77": 24211,
457
+ "78": 22590,
458
+ "79": 20897,
459
+ "80": 19153,
460
+ "81": 17657,
461
+ "82": 16849,
462
+ "83": 15301,
463
+ "84": 14235,
464
+ "85": 13049,
465
+ "86": 11929,
466
+ "87": 10779,
467
+ "88": 9912,
468
+ "89": 9146,
469
+ "90": 8247,
470
+ "91": 7534,
471
+ "92": 7104,
472
+ "93": 6609,
473
+ "94": 6159,
474
+ "95": 5758,
475
+ "96": 5510,
476
+ "97": 5528,
477
+ "98": 5293,
478
+ "99": 4973,
479
+ "100": 4606,
480
+ "101": 4275,
481
+ "102": 4271,
482
+ "103": 4037,
483
+ "104": 3971,
484
+ "105": 3787,
485
+ "106": 3970,
486
+ "107": 3630,
487
+ "108": 3605,
488
+ "109": 3351,
489
+ "110": 3229,
490
+ "111": 2970,
491
+ "112": 2963,
492
+ "113": 3005,
493
+ "114": 2790,
494
+ "115": 2728,
495
+ "116": 2547,
496
+ "117": 2315,
497
+ "118": 2133,
498
+ "119": 1910,
499
+ "120": 1701,
500
+ "121": 1579,
501
+ "122": 1382,
502
+ "123": 1253,
503
+ "124": 1198,
504
+ "125": 1048,
505
+ "126": 901,
506
+ "127": 847,
507
+ "128": 761,
508
+ "129": 656,
509
+ "130": 559,
510
+ "131": 543,
511
+ "132": 509,
512
+ "133": 497,
513
+ "134": 357,
514
+ "135": 353,
515
+ "136": 321,
516
+ "137": 252,
517
+ "138": 262,
518
+ "139": 215,
519
+ "140": 175,
520
+ "141": 188,
521
+ "142": 141,
522
+ "143": 138,
523
+ "144": 124,
524
+ "145": 141,
525
+ "146": 146,
526
+ "147": 147,
527
+ "148": 149,
528
+ "149": 156,
529
+ "150": 150,
530
+ "151": 122,
531
+ "152": 122,
532
+ "153": 118,
533
+ "154": 115,
534
+ "155": 142,
535
+ "156": 127,
536
+ "157": 105,
537
+ "158": 108,
538
+ "159": 96,
539
+ "160": 111,
540
+ "161": 106,
541
+ "162": 106,
542
+ "163": 99,
543
+ "164": 112,
544
+ "165": 108,
545
+ "166": 97,
546
+ "167": 106,
547
+ "168": 102,
548
+ "169": 109,
549
+ "170": 81,
550
+ "171": 118,
551
+ "172": 79,
552
+ "173": 67,
553
+ "174": 86,
554
+ "175": 34,
555
+ "176": 44,
556
+ "177": 22,
557
+ "178": 20,
558
+ "179": 20,
559
+ "180": 22,
560
+ "181": 6,
561
+ "182": 18,
562
+ "183": 12,
563
+ "184": 4,
564
+ "185": 6,
565
+ "186": 2,
566
+ "187": 2,
567
+ "188": 2
568
+ }
569
+ }
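
The `counts/*.json` files above appear to store, for each block size, a histogram mapping a per-block count to the number of blocks with that count. A minimal sketch of how such a file could be consumed, assuming exactly this layout (the file path and the chosen block size are illustrative):

```python
import json

# Assumed layout (from the data above): {block_size: {per-block count: number of blocks}}.
with open("counts/qnrf.json") as f:
    counts = json.load(f)

block_size = "32"                                  # any of the stored block sizes
hist = {int(c): n for c, n in counts[block_size].items()}
total_blocks = sum(hist.values())

# Empirical probability of each count value; note the heavy mass at zero (zero inflation).
probs = {c: n / total_blocks for c, n in sorted(hist.items())}
print(f"{total_blocks} blocks of size {block_size}, P(count=0) = {probs[0]:.4f}")
```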
counts/qnrf_max.json ADDED
@@ -0,0 +1,74 @@
+ {
+     "4": {
+         "max": 5,
+         "name": [
+             "0862.jpg"
+         ],
+         "x": 525,
+         "y": 892
+     },
+     "7": {
+         "max": 9,
+         "name": [
+             "0215.jpg"
+         ],
+         "x": 339,
+         "y": 701
+     },
+     "8": {
+         "max": 11,
+         "name": [
+             "0215.jpg"
+         ],
+         "x": 339,
+         "y": 701
+     },
+     "14": {
+         "max": 25,
+         "name": [
+             "0215.jpg"
+         ],
+         "x": 332,
+         "y": 697
+     },
+     "16": {
+         "max": 29,
+         "name": [
+             "0215.jpg"
+         ],
+         "x": 331,
+         "y": 697
+     },
+     "28": {
+         "max": 56,
+         "name": [
+             "0330.jpg"
+         ],
+         "x": 336,
+         "y": 1063
+     },
+     "32": {
+         "max": 65,
+         "name": [
+             "0931.jpg"
+         ],
+         "x": 730,
+         "y": 1077
+     },
+     "56": {
+         "max": 148,
+         "name": [
+             "0931.jpg"
+         ],
+         "x": 725,
+         "y": 1084
+     },
+     "64": {
+         "max": 188,
+         "name": [
+             "0931.jpg"
+         ],
+         "x": 702,
+         "y": 1078
+     }
+ }
counts/sha.json ADDED
@@ -0,0 +1,578 @@
1
+ {
2
+ "1": {
3
+ "0": 221398495,
4
+ "1": 162337
5
+ },
6
+ "4": {
7
+ "0": 217404440,
8
+ "1": 2560651,
9
+ "2": 15919,
10
+ "3": 525,
11
+ "4": 13
12
+ },
13
+ "7": {
14
+ "0": 210823747,
15
+ "1": 7296555,
16
+ "2": 246197,
17
+ "3": 33928,
18
+ "4": 5933,
19
+ "5": 1100,
20
+ "6": 185,
21
+ "7": 17,
22
+ "8": 2
23
+ },
24
+ "8": {
25
+ "0": 208176274,
26
+ "1": 9201723,
27
+ "2": 414603,
28
+ "3": 70943,
29
+ "4": 15652,
30
+ "5": 3870,
31
+ "6": 919,
32
+ "7": 204,
33
+ "8": 33,
34
+ "9": 13,
35
+ "10": 2
36
+ },
37
+ "14": {
38
+ "0": 189475328,
39
+ "1": 21483041,
40
+ "2": 2585228,
41
+ "3": 687491,
42
+ "4": 268045,
43
+ "5": 123917,
44
+ "6": 62564,
45
+ "7": 32805,
46
+ "8": 17268,
47
+ "9": 9346,
48
+ "10": 5150,
49
+ "11": 2866,
50
+ "12": 1541,
51
+ "13": 822,
52
+ "14": 463,
53
+ "15": 198,
54
+ "16": 99,
55
+ "17": 48,
56
+ "18": 30,
57
+ "19": 16,
58
+ "20": 2
59
+ },
60
+ "16": {
61
+ "0": 182715184,
62
+ "1": 25283086,
63
+ "2": 3723263,
64
+ "3": 1075878,
65
+ "4": 428688,
66
+ "5": 212379,
67
+ "6": 116019,
68
+ "7": 66336,
69
+ "8": 38808,
70
+ "9": 22851,
71
+ "10": 13845,
72
+ "11": 8410,
73
+ "12": 5248,
74
+ "13": 3355,
75
+ "14": 1997,
76
+ "15": 1309,
77
+ "16": 818,
78
+ "17": 451,
79
+ "18": 238,
80
+ "19": 113,
81
+ "20": 65,
82
+ "21": 31,
83
+ "22": 16,
84
+ "23": 8,
85
+ "24": 8,
86
+ "25": 4,
87
+ "26": 2,
88
+ "27": 1,
89
+ "28": 1
90
+ },
91
+ "28": {
92
+ "0": 143735006,
93
+ "1": 40200526,
94
+ "2": 11837381,
95
+ "3": 4979488,
96
+ "4": 2499799,
97
+ "5": 1387540,
98
+ "6": 843561,
99
+ "7": 542333,
100
+ "8": 361946,
101
+ "9": 254697,
102
+ "10": 190403,
103
+ "11": 145704,
104
+ "12": 112559,
105
+ "13": 88972,
106
+ "14": 69603,
107
+ "15": 55162,
108
+ "16": 44351,
109
+ "17": 36430,
110
+ "18": 29187,
111
+ "19": 23408,
112
+ "20": 18831,
113
+ "21": 14678,
114
+ "22": 11890,
115
+ "23": 9916,
116
+ "24": 8375,
117
+ "25": 6759,
118
+ "26": 5676,
119
+ "27": 4713,
120
+ "28": 3932,
121
+ "29": 3328,
122
+ "30": 2705,
123
+ "31": 2351,
124
+ "32": 1976,
125
+ "33": 1691,
126
+ "34": 1450,
127
+ "35": 1133,
128
+ "36": 949,
129
+ "37": 790,
130
+ "38": 657,
131
+ "39": 489,
132
+ "40": 324,
133
+ "41": 244,
134
+ "42": 170,
135
+ "43": 144,
136
+ "44": 122,
137
+ "45": 85,
138
+ "46": 63,
139
+ "47": 48,
140
+ "48": 44,
141
+ "49": 38,
142
+ "50": 16,
143
+ "51": 13,
144
+ "52": 14,
145
+ "53": 4,
146
+ "54": 1,
147
+ "55": 1
148
+ },
149
+ "32": {
150
+ "0": 132333123,
151
+ "1": 42343779,
152
+ "2": 14212292,
153
+ "3": 6506905,
154
+ "4": 3472682,
155
+ "5": 2008349,
156
+ "6": 1259746,
157
+ "7": 835061,
158
+ "8": 580977,
159
+ "9": 411224,
160
+ "10": 298598,
161
+ "11": 226760,
162
+ "12": 177966,
163
+ "13": 142932,
164
+ "14": 116283,
165
+ "15": 95533,
166
+ "16": 78053,
167
+ "17": 64149,
168
+ "18": 52398,
169
+ "19": 43187,
170
+ "20": 36642,
171
+ "21": 31022,
172
+ "22": 26409,
173
+ "23": 22474,
174
+ "24": 19080,
175
+ "25": 15785,
176
+ "26": 12983,
177
+ "27": 10905,
178
+ "28": 9540,
179
+ "29": 8242,
180
+ "30": 7113,
181
+ "31": 5838,
182
+ "32": 4817,
183
+ "33": 4147,
184
+ "34": 3635,
185
+ "35": 3160,
186
+ "36": 2800,
187
+ "37": 2258,
188
+ "38": 2086,
189
+ "39": 1884,
190
+ "40": 1789,
191
+ "41": 1749,
192
+ "42": 1451,
193
+ "43": 1284,
194
+ "44": 1097,
195
+ "45": 849,
196
+ "46": 631,
197
+ "47": 498,
198
+ "48": 324,
199
+ "49": 294,
200
+ "50": 208,
201
+ "51": 157,
202
+ "52": 136,
203
+ "53": 119,
204
+ "54": 80,
205
+ "55": 73,
206
+ "56": 68,
207
+ "57": 58,
208
+ "58": 61,
209
+ "59": 69,
210
+ "60": 49,
211
+ "61": 40,
212
+ "62": 24,
213
+ "63": 15,
214
+ "64": 17,
215
+ "65": 5,
216
+ "66": 2
217
+ },
218
+ "56": {
219
+ "0": 82311154,
220
+ "1": 39190874,
221
+ "2": 22002574,
222
+ "3": 12962751,
223
+ "4": 8404996,
224
+ "5": 5832072,
225
+ "6": 4289203,
226
+ "7": 3257294,
227
+ "8": 2542514,
228
+ "9": 1997552,
229
+ "10": 1600093,
230
+ "11": 1286154,
231
+ "12": 1053852,
232
+ "13": 868200,
233
+ "14": 718305,
234
+ "15": 598864,
235
+ "16": 498449,
236
+ "17": 418687,
237
+ "18": 358697,
238
+ "19": 312381,
239
+ "20": 276011,
240
+ "21": 241729,
241
+ "22": 215353,
242
+ "23": 195921,
243
+ "24": 175559,
244
+ "25": 159251,
245
+ "26": 141084,
246
+ "27": 128022,
247
+ "28": 114886,
248
+ "29": 104495,
249
+ "30": 95802,
250
+ "31": 87751,
251
+ "32": 79668,
252
+ "33": 72856,
253
+ "34": 67187,
254
+ "35": 60598,
255
+ "36": 56041,
256
+ "37": 49833,
257
+ "38": 45739,
258
+ "39": 42100,
259
+ "40": 38922,
260
+ "41": 35683,
261
+ "42": 33222,
262
+ "43": 31037,
263
+ "44": 27306,
264
+ "45": 24412,
265
+ "46": 21939,
266
+ "47": 20087,
267
+ "48": 18312,
268
+ "49": 17285,
269
+ "50": 16026,
270
+ "51": 14905,
271
+ "52": 14599,
272
+ "53": 13990,
273
+ "54": 13420,
274
+ "55": 12785,
275
+ "56": 11938,
276
+ "57": 11445,
277
+ "58": 11094,
278
+ "59": 10387,
279
+ "60": 9826,
280
+ "61": 9605,
281
+ "62": 9270,
282
+ "63": 8533,
283
+ "64": 8157,
284
+ "65": 7849,
285
+ "66": 7121,
286
+ "67": 6586,
287
+ "68": 6083,
288
+ "69": 5424,
289
+ "70": 4978,
290
+ "71": 4867,
291
+ "72": 4364,
292
+ "73": 3995,
293
+ "74": 3771,
294
+ "75": 3567,
295
+ "76": 3107,
296
+ "77": 2871,
297
+ "78": 2630,
298
+ "79": 2162,
299
+ "80": 2096,
300
+ "81": 1907,
301
+ "82": 1872,
302
+ "83": 1792,
303
+ "84": 1838,
304
+ "85": 1703,
305
+ "86": 1629,
306
+ "87": 1545,
307
+ "88": 1388,
308
+ "89": 1298,
309
+ "90": 1310,
310
+ "91": 1258,
311
+ "92": 1175,
312
+ "93": 1174,
313
+ "94": 1013,
314
+ "95": 976,
315
+ "96": 856,
316
+ "97": 784,
317
+ "98": 711,
318
+ "99": 692,
319
+ "100": 697,
320
+ "101": 622,
321
+ "102": 639,
322
+ "103": 544,
323
+ "104": 531,
324
+ "105": 476,
325
+ "106": 481,
326
+ "107": 450,
327
+ "108": 443,
328
+ "109": 439,
329
+ "110": 443,
330
+ "111": 358,
331
+ "112": 337,
332
+ "113": 293,
333
+ "114": 264,
334
+ "115": 223,
335
+ "116": 177,
336
+ "117": 140,
337
+ "118": 143,
338
+ "119": 124,
339
+ "120": 118,
340
+ "121": 104,
341
+ "122": 100,
342
+ "123": 96,
343
+ "124": 94,
344
+ "125": 92,
345
+ "126": 69,
346
+ "127": 89,
347
+ "128": 91,
348
+ "129": 85,
349
+ "130": 70,
350
+ "131": 66,
351
+ "132": 51,
352
+ "133": 54,
353
+ "134": 77,
354
+ "135": 60,
355
+ "136": 69,
356
+ "137": 62,
357
+ "138": 75,
358
+ "139": 83,
359
+ "140": 84,
360
+ "141": 77,
361
+ "142": 63,
362
+ "143": 51,
363
+ "144": 51,
364
+ "145": 68,
365
+ "146": 44,
366
+ "147": 45,
367
+ "148": 35,
368
+ "149": 38,
369
+ "150": 39,
370
+ "151": 39,
371
+ "152": 22,
372
+ "153": 12,
373
+ "154": 19,
374
+ "155": 24,
375
+ "156": 15,
376
+ "157": 4,
377
+ "158": 3,
378
+ "159": 1
379
+ },
380
+ "64": {
381
+ "0": 71204848,
382
+ "1": 35431716,
383
+ "2": 22345768,
384
+ "3": 14110543,
385
+ "4": 9458039,
386
+ "5": 6781297,
387
+ "6": 5068480,
388
+ "7": 3922313,
389
+ "8": 3115679,
390
+ "9": 2546969,
391
+ "10": 2092914,
392
+ "11": 1728554,
393
+ "12": 1445669,
394
+ "13": 1226006,
395
+ "14": 1027888,
396
+ "15": 880413,
397
+ "16": 758676,
398
+ "17": 651263,
399
+ "18": 560175,
400
+ "19": 481484,
401
+ "20": 415366,
402
+ "21": 360995,
403
+ "22": 319926,
404
+ "23": 281587,
405
+ "24": 249589,
406
+ "25": 222763,
407
+ "26": 201505,
408
+ "27": 186993,
409
+ "28": 172894,
410
+ "29": 160066,
411
+ "30": 148490,
412
+ "31": 135929,
413
+ "32": 125730,
414
+ "33": 116554,
415
+ "34": 109632,
416
+ "35": 101625,
417
+ "36": 93920,
418
+ "37": 86856,
419
+ "38": 80031,
420
+ "39": 73701,
421
+ "40": 68720,
422
+ "41": 62813,
423
+ "42": 58001,
424
+ "43": 53537,
425
+ "44": 49124,
426
+ "45": 45340,
427
+ "46": 42598,
428
+ "47": 39746,
429
+ "48": 37319,
430
+ "49": 35173,
431
+ "50": 32861,
432
+ "51": 29710,
433
+ "52": 27037,
434
+ "53": 24220,
435
+ "54": 22338,
436
+ "55": 20642,
437
+ "56": 19097,
438
+ "57": 17737,
439
+ "58": 16334,
440
+ "59": 16276,
441
+ "60": 15705,
442
+ "61": 14837,
443
+ "62": 13992,
444
+ "63": 13180,
445
+ "64": 12950,
446
+ "65": 12540,
447
+ "66": 12527,
448
+ "67": 12219,
449
+ "68": 11564,
450
+ "69": 10978,
451
+ "70": 10465,
452
+ "71": 9857,
453
+ "72": 9330,
454
+ "73": 9088,
455
+ "74": 8851,
456
+ "75": 8715,
457
+ "76": 8399,
458
+ "77": 7778,
459
+ "78": 7275,
460
+ "79": 6728,
461
+ "80": 6557,
462
+ "81": 6062,
463
+ "82": 5907,
464
+ "83": 5520,
465
+ "84": 5272,
466
+ "85": 4972,
467
+ "86": 4439,
468
+ "87": 3988,
469
+ "88": 3607,
470
+ "89": 3342,
471
+ "90": 3260,
472
+ "91": 3148,
473
+ "92": 2978,
474
+ "93": 3015,
475
+ "94": 2783,
476
+ "95": 2642,
477
+ "96": 2436,
478
+ "97": 2283,
479
+ "98": 2134,
480
+ "99": 2055,
481
+ "100": 1914,
482
+ "101": 1877,
483
+ "102": 1641,
484
+ "103": 1643,
485
+ "104": 1537,
486
+ "105": 1521,
487
+ "106": 1459,
488
+ "107": 1329,
489
+ "108": 1227,
490
+ "109": 1124,
491
+ "110": 1085,
492
+ "111": 1003,
493
+ "112": 967,
494
+ "113": 837,
495
+ "114": 748,
496
+ "115": 695,
497
+ "116": 680,
498
+ "117": 662,
499
+ "118": 590,
500
+ "119": 584,
501
+ "120": 596,
502
+ "121": 630,
503
+ "122": 608,
504
+ "123": 567,
505
+ "124": 549,
506
+ "125": 535,
507
+ "126": 485,
508
+ "127": 432,
509
+ "128": 387,
510
+ "129": 379,
511
+ "130": 390,
512
+ "131": 364,
513
+ "132": 288,
514
+ "133": 321,
515
+ "134": 302,
516
+ "135": 280,
517
+ "136": 268,
518
+ "137": 287,
519
+ "138": 270,
520
+ "139": 262,
521
+ "140": 222,
522
+ "141": 196,
523
+ "142": 170,
524
+ "143": 136,
525
+ "144": 155,
526
+ "145": 122,
527
+ "146": 115,
528
+ "147": 114,
529
+ "148": 96,
530
+ "149": 98,
531
+ "150": 83,
532
+ "151": 94,
533
+ "152": 94,
534
+ "153": 84,
535
+ "154": 77,
536
+ "155": 88,
537
+ "156": 70,
538
+ "157": 66,
539
+ "158": 60,
540
+ "159": 78,
541
+ "160": 59,
542
+ "161": 57,
543
+ "162": 63,
544
+ "163": 74,
545
+ "164": 63,
546
+ "165": 52,
547
+ "166": 65,
548
+ "167": 50,
549
+ "168": 76,
550
+ "169": 63,
551
+ "170": 63,
552
+ "171": 67,
553
+ "172": 62,
554
+ "173": 47,
555
+ "174": 51,
556
+ "175": 38,
557
+ "176": 42,
558
+ "177": 44,
559
+ "178": 44,
560
+ "179": 39,
561
+ "180": 45,
562
+ "181": 42,
563
+ "182": 31,
564
+ "183": 27,
565
+ "184": 39,
566
+ "185": 21,
567
+ "186": 28,
568
+ "187": 23,
569
+ "188": 36,
570
+ "189": 24,
571
+ "190": 11,
572
+ "191": 11,
573
+ "192": 11,
574
+ "193": 6,
575
+ "194": 5,
576
+ "195": 1
577
+ }
578
+ }
counts/sha_max.json ADDED
@@ -0,0 +1,74 @@
+ {
+     "4": {
+         "max": 4,
+         "name": [
+             "007.jpg"
+         ],
+         "x": 324,
+         "y": 176
+     },
+     "7": {
+         "max": 8,
+         "name": [
+             "034.jpg"
+         ],
+         "x": 271,
+         "y": 341
+     },
+     "8": {
+         "max": 10,
+         "name": [
+             "034.jpg"
+         ],
+         "x": 271,
+         "y": 340
+     },
+     "14": {
+         "max": 20,
+         "name": [
+             "120.jpg"
+         ],
+         "x": 295,
+         "y": 762
+     },
+     "16": {
+         "max": 28,
+         "name": [
+             "120.jpg"
+         ],
+         "x": 296,
+         "y": 760
+     },
+     "28": {
+         "max": 55,
+         "name": [
+             "120.jpg"
+         ],
+         "x": 303,
+         "y": 652
+     },
+     "32": {
+         "max": 66,
+         "name": [
+             "120.jpg"
+         ],
+         "x": 313,
+         "y": 651
+     },
+     "56": {
+         "max": 159,
+         "name": [
+             "120.jpg"
+         ],
+         "x": 301,
+         "y": 655
+     },
+     "64": {
+         "max": 195,
+         "name": [
+             "120.jpg"
+         ],
+         "x": 301,
+         "y": 657
+     }
+ }
counts/shb.json ADDED
@@ -0,0 +1,313 @@
1
+ {
2
+ "1": {
3
+ "0": 314523695,
4
+ "1": 49105
5
+ },
6
+ "4": {
7
+ "0": 311650011,
8
+ "1": 772635,
9
+ "2": 3256,
10
+ "3": 95,
11
+ "4": 3
12
+ },
13
+ "7": {
14
+ "0": 308004073,
15
+ "1": 2221020,
16
+ "2": 55546,
17
+ "3": 4833,
18
+ "4": 681,
19
+ "5": 181,
20
+ "6": 53,
21
+ "7": 10,
22
+ "8": 3
23
+ },
24
+ "8": {
25
+ "0": 306646552,
26
+ "1": 2818806,
27
+ "2": 96449,
28
+ "3": 10596,
29
+ "4": 1733,
30
+ "5": 447,
31
+ "6": 138,
32
+ "7": 57,
33
+ "8": 22
34
+ },
35
+ "14": {
36
+ "0": 297434203,
37
+ "1": 7041868,
38
+ "2": 636791,
39
+ "3": 139674,
40
+ "4": 42083,
41
+ "5": 15292,
42
+ "6": 6324,
43
+ "7": 3122,
44
+ "8": 1298,
45
+ "9": 595,
46
+ "10": 304,
47
+ "11": 225,
48
+ "12": 169,
49
+ "13": 45,
50
+ "14": 7
51
+ },
52
+ "16": {
53
+ "0": 294072360,
54
+ "1": 8559657,
55
+ "2": 922301,
56
+ "3": 225017,
57
+ "4": 75708,
58
+ "5": 29428,
59
+ "6": 12533,
60
+ "7": 6347,
61
+ "8": 3429,
62
+ "9": 1869,
63
+ "10": 913,
64
+ "11": 494,
65
+ "12": 338,
66
+ "13": 202,
67
+ "14": 192,
68
+ "15": 11,
69
+ "16": 1
70
+ },
71
+ "28": {
72
+ "0": 272510235,
73
+ "1": 17410504,
74
+ "2": 3291284,
75
+ "3": 1142143,
76
+ "4": 507297,
77
+ "5": 259215,
78
+ "6": 143543,
79
+ "7": 86057,
80
+ "8": 52776,
81
+ "9": 33818,
82
+ "10": 22305,
83
+ "11": 14778,
84
+ "12": 9902,
85
+ "13": 6909,
86
+ "14": 4829,
87
+ "15": 3511,
88
+ "16": 2765,
89
+ "17": 2161,
90
+ "18": 1627,
91
+ "19": 1396,
92
+ "20": 1075,
93
+ "21": 796,
94
+ "22": 639,
95
+ "23": 520,
96
+ "24": 375,
97
+ "25": 205,
98
+ "26": 92,
99
+ "27": 27,
100
+ "28": 10,
101
+ "29": 4,
102
+ "30": 2
103
+ },
104
+ "32": {
105
+ "0": 265135522,
106
+ "1": 20054326,
107
+ "2": 4219708,
108
+ "3": 1561515,
109
+ "4": 730071,
110
+ "5": 382477,
111
+ "6": 224559,
112
+ "7": 137037,
113
+ "8": 88156,
114
+ "9": 58687,
115
+ "10": 40153,
116
+ "11": 27989,
117
+ "12": 19367,
118
+ "13": 13555,
119
+ "14": 10126,
120
+ "15": 7417,
121
+ "16": 5593,
122
+ "17": 4242,
123
+ "18": 3235,
124
+ "19": 2714,
125
+ "20": 2136,
126
+ "21": 1687,
127
+ "22": 1343,
128
+ "23": 1093,
129
+ "24": 990,
130
+ "25": 881,
131
+ "26": 651,
132
+ "27": 428,
133
+ "28": 278,
134
+ "29": 173,
135
+ "30": 116,
136
+ "31": 83,
137
+ "32": 43,
138
+ "33": 36,
139
+ "34": 8,
140
+ "35": 3,
141
+ "36": 2
142
+ },
143
+ "56": {
144
+ "0": 222314024,
145
+ "1": 32191189,
146
+ "2": 9727123,
147
+ "3": 4342794,
148
+ "4": 2404979,
149
+ "5": 1505427,
150
+ "6": 1000917,
151
+ "7": 701563,
152
+ "8": 499165,
153
+ "9": 362489,
154
+ "10": 267104,
155
+ "11": 199980,
156
+ "12": 153876,
157
+ "13": 123592,
158
+ "14": 98575,
159
+ "15": 80346,
160
+ "16": 63904,
161
+ "17": 48447,
162
+ "18": 40380,
163
+ "19": 33358,
164
+ "20": 28391,
165
+ "21": 24691,
166
+ "22": 21645,
167
+ "23": 17519,
168
+ "24": 14226,
169
+ "25": 11839,
170
+ "26": 10556,
171
+ "27": 8884,
172
+ "28": 7573,
173
+ "29": 6473,
174
+ "30": 5818,
175
+ "31": 4784,
176
+ "32": 4100,
177
+ "33": 4039,
178
+ "34": 3497,
179
+ "35": 2721,
180
+ "36": 2238,
181
+ "37": 2208,
182
+ "38": 2072,
183
+ "39": 2096,
184
+ "40": 1750,
185
+ "41": 1466,
186
+ "42": 1404,
187
+ "43": 1196,
188
+ "44": 1138,
189
+ "45": 918,
190
+ "46": 786,
191
+ "47": 672,
192
+ "48": 698,
193
+ "49": 688,
194
+ "50": 610,
195
+ "51": 537,
196
+ "52": 469,
197
+ "53": 448,
198
+ "54": 346,
199
+ "55": 264,
200
+ "56": 198,
201
+ "57": 168,
202
+ "58": 131,
203
+ "59": 54,
204
+ "60": 28,
205
+ "61": 34,
206
+ "62": 22,
207
+ "63": 10,
208
+ "64": 18,
209
+ "65": 17,
210
+ "66": 16,
211
+ "67": 25,
212
+ "68": 21,
213
+ "69": 25,
214
+ "70": 11,
215
+ "71": 13,
216
+ "72": 7,
217
+ "73": 2,
218
+ "74": 4,
219
+ "76": 4
220
+ },
221
+ "64": {
222
+ "0": 209048823,
223
+ "1": 34905056,
224
+ "2": 11413735,
225
+ "3": 5278103,
226
+ "4": 2980067,
227
+ "5": 1886714,
228
+ "6": 1308620,
229
+ "7": 945805,
230
+ "8": 684080,
231
+ "9": 516549,
232
+ "10": 387772,
233
+ "11": 301510,
234
+ "12": 234031,
235
+ "13": 186750,
236
+ "14": 149049,
237
+ "15": 124290,
238
+ "16": 101853,
239
+ "17": 81550,
240
+ "18": 68680,
241
+ "19": 55441,
242
+ "20": 45411,
243
+ "21": 39050,
244
+ "22": 33804,
245
+ "23": 30803,
246
+ "24": 24284,
247
+ "25": 20547,
248
+ "26": 17358,
249
+ "27": 14546,
250
+ "28": 12847,
251
+ "29": 11443,
252
+ "30": 9852,
253
+ "31": 8715,
254
+ "32": 7569,
255
+ "33": 6927,
256
+ "34": 6284,
257
+ "35": 5688,
258
+ "36": 4647,
259
+ "37": 4476,
260
+ "38": 3947,
261
+ "39": 3756,
262
+ "40": 3232,
263
+ "41": 2883,
264
+ "42": 2580,
265
+ "43": 2338,
266
+ "44": 2092,
267
+ "45": 1930,
268
+ "46": 1670,
269
+ "47": 1514,
270
+ "48": 1470,
271
+ "49": 1361,
272
+ "50": 1267,
273
+ "51": 1218,
274
+ "52": 939,
275
+ "53": 852,
276
+ "54": 738,
277
+ "55": 662,
278
+ "56": 628,
279
+ "57": 690,
280
+ "58": 495,
281
+ "59": 508,
282
+ "60": 441,
283
+ "61": 401,
284
+ "62": 333,
285
+ "63": 314,
286
+ "64": 194,
287
+ "65": 130,
288
+ "66": 108,
289
+ "67": 108,
290
+ "68": 91,
291
+ "69": 72,
292
+ "70": 32,
293
+ "71": 29,
294
+ "72": 32,
295
+ "73": 20,
296
+ "74": 17,
297
+ "75": 11,
298
+ "76": 21,
299
+ "77": 15,
300
+ "78": 17,
301
+ "79": 21,
302
+ "80": 20,
303
+ "81": 13,
304
+ "82": 17,
305
+ "83": 9,
306
+ "84": 8,
307
+ "85": 5,
308
+ "86": 10,
309
+ "87": 3,
310
+ "88": 5,
311
+ "89": 4
312
+ }
313
+ }
counts/shb_max.json ADDED
@@ -0,0 +1,74 @@
+ {
+     "4": {
+         "max": 4,
+         "name": [
+             "148.jpg"
+         ],
+         "x": 40,
+         "y": 549
+     },
+     "7": {
+         "max": 8,
+         "name": [
+             "200.jpg"
+         ],
+         "x": 275,
+         "y": 37
+     },
+     "8": {
+         "max": 8,
+         "name": [
+             "148.jpg"
+         ],
+         "x": 39,
+         "y": 550
+     },
+     "14": {
+         "max": 14,
+         "name": [
+             "200.jpg"
+         ],
+         "x": 269,
+         "y": 37
+     },
+     "16": {
+         "max": 16,
+         "name": [
+             "191.jpg"
+         ],
+         "x": 1,
+         "y": 257
+     },
+     "28": {
+         "max": 30,
+         "name": [
+             "191.jpg"
+         ],
+         "x": 0,
+         "y": 257
+     },
+     "32": {
+         "max": 36,
+         "name": [
+             "191.jpg"
+         ],
+         "x": 0,
+         "y": 256
+     },
+     "56": {
+         "max": 76,
+         "name": [
+             "191.jpg"
+         ],
+         "x": 0,
+         "y": 256
+     },
+     "64": {
+         "max": 89,
+         "name": [
+             "191.jpg"
+         ],
+         "x": 1,
+         "y": 254
+     }
+ }
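
The companion `counts/*_max.json` files record, per block size, the largest per-block count observed, together with the image(s) attaining it and a coordinate pair. A minimal sketch for inspecting one of them; reading `x`/`y` as the location of that densest block is an assumption based on the fields above:

```python
import json

with open("counts/shb_max.json") as f:
    max_info = json.load(f)

for block_size, info in max_info.items():
    # info["name"] is a list of image file names attaining the maximum count.
    print(f"{block_size}x{block_size}: max count {info['max']} "
          f"in {', '.join(info['name'])} at (x={info['x']}, y={info['y']})")
```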
datasets/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .crowd import Crowd, InMemoryCrowd, available_datasets, standardize_dataset_name, NWPUTest, ShanghaiTech
+ from .transforms import RandomCrop, Resize, RandomResizedCrop, RandomHorizontalFlip, Resize2Multiple, ZeroPad2Multiple
+ from .transforms import ColorJitter, RandomGrayscale, GaussianBlur, RandomApply, PepperSaltNoise
+ from .utils import collate_fn
+
+
+ __all__ = [
+     "Crowd", "InMemoryCrowd", "available_datasets", "standardize_dataset_name", "NWPUTest", "ShanghaiTech",
+     "RandomCrop", "Resize", "RandomResizedCrop", "RandomHorizontalFlip", "Resize2Multiple", "ZeroPad2Multiple",
+     "ColorJitter", "RandomGrayscale", "GaussianBlur", "RandomApply", "PepperSaltNoise",
+     "collate_fn",
+ ]
datasets/crowd.py ADDED
@@ -0,0 +1,309 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from torch.utils.data import Dataset
4
+ from torchvision.transforms import ToTensor, Normalize, Compose
5
+ import os
6
+ from glob import glob
7
+ from tqdm import tqdm
8
+ # from PIL import Image
9
+ from turbojpeg import TurboJPEG, TJPF_RGB
10
+ jpeg_decoder = TurboJPEG()
11
+
12
+ import numpy as np
13
+ from typing import Optional, Callable, Union, Tuple
14
+
15
+ from .utils import get_id, generate_density_map
16
+
17
+ curr_dir = os.path.dirname(os.path.abspath(__file__))
18
+
19
+ available_datasets = [
20
+ "shanghaitech_a", "sha",
21
+ "shanghaitech_b", "shb",
22
+ "shanghaitech", "sh",
23
+ "ucf_qnrf", "qnrf", "ucf-qnrf",
24
+ "nwpu", "nwpu_crowd", "nwpu-crowd",
25
+ ]
26
+
27
+ mean = (0.48145466, 0.4578275, 0.40821073)
28
+ std = (0.26862954, 0.26130258, 0.27577711)
29
+
30
+
31
+ def standardize_dataset_name(dataset: str) -> str:
32
+ assert dataset.lower() in available_datasets, f"Dataset {dataset} is not available."
33
+ if dataset.lower() in ["shanghaitech_a", "sha"]:
34
+ return "sha"
35
+ elif dataset.lower() in ["shanghaitech_b", "shb"]:
36
+ return "shb"
37
+ elif dataset.lower() in ["shanghaitech", "sh"]:
38
+ return "sh"
39
+ elif dataset.lower() in ["ucf_qnrf", "qnrf", "ucf-qnrf"]:
40
+ return "qnrf"
41
+ else:
42
+ assert dataset.lower() in ["nwpu", "nwpu_crowd", "nwpu-crowd"], f"Dataset {dataset} is not available."
43
+ return "nwpu"
44
+
45
+
46
+ class Crowd(Dataset):
47
+ def __init__(
48
+ self,
49
+ dataset: str,
50
+ split: str,
51
+ transforms: Optional[Callable] = None,
52
+ sigma: Optional[float] = None,
53
+ return_filename: bool = False,
54
+ num_crops: int = 1,
55
+ ) -> None:
56
+ """
57
+ Dataset for crowd counting.
58
+ """
59
+ assert dataset.lower() in available_datasets, f"Dataset {dataset} is not available."
60
+ assert dataset.lower() not in ["shanghaitech", "sh"], "For the combined ShanghaiTech dataset, use ShanghaiTech class."
61
+ assert split in ["train", "val", "test"], f"Split {split} is not available."
62
+ assert num_crops > 0, f"num_crops should be positive, got {num_crops}."
63
+
64
+ self.dataset = standardize_dataset_name(dataset)
65
+ self.split = split
66
+
67
+ self.__find_root__()
68
+ self.__make_dataset__()
69
+ self.__check_sanity__()
70
+
71
+ self.to_tensor = ToTensor()
72
+ self.normalize = Normalize(mean=mean, std=std)
73
+ self.transforms = transforms
74
+
75
+ self.sigma = sigma
76
+ self.return_filename = return_filename
77
+ self.num_crops = num_crops
78
+
79
+ def __find_root__(self) -> None:
80
+ self.root = os.path.join(curr_dir, "..", "data", self.dataset)
81
+
82
+ def __make_dataset__(self) -> None:
83
+ image_names = glob(os.path.join(self.root, self.split, "images", "*.jpg"))
84
+
85
+ label_names = glob(os.path.join(self.root, self.split, "labels", "*.npy"))
86
+ image_names = [os.path.basename(image_name) for image_name in image_names]
87
+ label_names = [os.path.basename(label_name) for label_name in label_names]
88
+ image_names.sort(key=get_id)
89
+ label_names.sort(key=get_id)
90
+ image_ids = tuple([get_id(image_name) for image_name in image_names])
91
+ label_ids = tuple([get_id(label_name) for label_name in label_names])
92
+ assert image_ids == label_ids, "image_ids and label_ids do not match."
93
+ self.image_names = tuple(image_names)
94
+ self.label_names = tuple(label_names)
95
+
96
+ def __check_sanity__(self) -> None:
97
+ if self.dataset == "sha":
98
+ if self.split == "train":
99
+ assert len(self.image_names) == len(self.label_names) == 300, f"ShanghaiTech_A train split should have 300 images, but found {len(self.image_names)}."
100
+ else:
101
+ assert self.split == "val", f"Split {self.split} is not available for dataset {self.dataset}."
102
+ assert len(self.image_names) == len(self.label_names) == 182, f"ShanghaiTech_A val split should have 182 images, but found {len(self.image_names)}."
103
+ elif self.dataset == "shb":
104
+ if self.split == "train":
105
+ assert len(self.image_names) == len(self.label_names) == 399, f"ShanghaiTech_B train split should have 399 images, but found {len(self.image_names)}."
106
+ else:
107
+ assert self.split == "val", f"Split {self.split} is not available for dataset {self.dataset}."
108
+ assert len(self.image_names) == len(self.label_names) == 316, f"ShanghaiTech_B val split should have 316 images, but found {len(self.image_names)}."
109
+ elif self.dataset == "nwpu":
110
+ if self.split == "train":
111
+ assert len(self.image_names) == len(self.label_names) == 3109, f"NWPU train split should have 3109 images, but found {len(self.image_names)}."
112
+ else:
113
+ assert self.split == "val", f"Split {self.split} is not available for dataset {self.dataset}."
114
+ assert len(self.image_names) == len(self.label_names) == 500, f"NWPU val split should have 500 images, but found {len(self.image_names)}."
115
+ elif self.dataset == "qnrf":
116
+ if self.split == "train":
117
+ assert len(self.image_names) == len(self.label_names) == 1201, f"UCF_QNRF train split should have 1201 images, but found {len(self.image_names)}."
118
+ else:
119
+ assert self.split == "val", f"Split {self.split} is not available for dataset {self.dataset}."
120
+ assert len(self.image_names) == len(self.label_names) == 334, f"UCF_QNRF val split should have 334 images, but found {len(self.image_names)}."
121
+
122
+ def __len__(self) -> int:
123
+ return len(self.image_names)
124
+
125
+ def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
126
+ image_name = self.image_names[idx]
127
+ label_name = self.label_names[idx]
128
+
129
+ image_path = os.path.join(self.root, self.split, "images", image_name)
130
+ label_path = os.path.join(self.root, self.split, "labels", label_name)
131
+
132
+ with open(image_path, "rb") as f:
133
+ # image = Image.open(f).convert("RGB")
134
+ image = jpeg_decoder.decode(f.read(), pixel_format=TJPF_RGB)
135
+ image = self.to_tensor(image)
136
+
137
+ with open(label_path, "rb") as f:
138
+ label = np.load(f)
139
+ label = torch.from_numpy(label).float()
140
+
141
+ if self.transforms is not None:
142
+ images_labels = [self.transforms(image.clone(), label.clone()) for _ in range(self.num_crops)]
143
+ images, labels = zip(*images_labels)
144
+ else:
145
+ images = [image.clone() for _ in range(self.num_crops)]
146
+ labels = [label.clone() for _ in range(self.num_crops)]
147
+
148
+ images = [self.normalize(img) for img in images]
149
+ density_maps = torch.stack([generate_density_map(label, image.shape[-2], image.shape[-1], sigma=self.sigma) for image, label in zip(images, labels)], 0)
150
+ image_names = [image_name] * len(images)
151
+ images = torch.stack(images, 0)
152
+
153
+ if self.return_filename:
154
+ return images, labels, density_maps, image_names
155
+ else:
156
+ return images, labels, density_maps
157
+
158
+
159
+ class InMemoryCrowd(Dataset):
160
+ def __init__(
161
+ self,
162
+ dataset: str,
163
+ split: str,
164
+ transforms: Optional[Callable] = None,
165
+ sigma: Optional[float] = None,
166
+ return_filename: bool = False,
167
+ num_crops: int = 1,
168
+ ) -> None:
169
+ """
170
+ Dataset for crowd counting, with images and labels loaded into memory.
171
+ """
172
+ crowd = Crowd(
173
+ dataset=dataset,
174
+ split=split,
175
+ transforms=None,
176
+ sigma=sigma,
177
+ return_filename=True,
178
+ num_crops=1,
179
+ )
180
+ print(f"Loading {len(crowd)} samples from {dataset} {split} split into memory...")
181
+ self.images, self.labels, self.image_names = [], [], []
182
+ self.unnormalize = Compose([
183
+ Normalize(mean=(0., 0., 0.), std=(1./std[0], 1./std[1], 1./std[2]), inplace=True),
184
+ Normalize(mean=(-mean[0], -mean[1], -mean[2]), std=(1., 1., 1.), inplace=True)
185
+ ])
186
+
187
+ for i in tqdm(range(len(crowd)), desc="Loading images and labels into memory"):
188
+ image, label, _, image_name = crowd[i]
189
+ self.images.append(self.unnormalize(image[0])) # recover original image
190
+ self.labels.append(label[0])
191
+ self.image_names.append(image_name[0])
192
+
193
+ assert len(self.images) == len(self.labels) == len(self.image_names), "Mismatch in number of images, labels, and image names."
194
+
195
+ self.transforms = transforms
196
+ self.sigma = sigma
197
+ self.num_crops = num_crops
198
+ self.return_filename = return_filename
199
+ self.normalize = Normalize(mean=mean, std=std, inplace=False)
200
+
201
+ def __len__(self) -> int:
202
+ return len(self.images)
203
+
204
+ def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
205
+ image, label, image_name = self.images[idx].clone(), self.labels[idx].clone(), self.image_names[idx]
206
+
207
+ if self.transforms is not None:
208
+ images_labels = [self.transforms(image.clone(), label.clone()) for _ in range(self.num_crops)]
209
+ images, labels = zip(*images_labels)
210
+ else:
211
+ images = [image.clone() for _ in range(self.num_crops)]
212
+ labels = [label.clone() for _ in range(self.num_crops)]
213
+
214
+ images = [self.normalize(img) for img in images]
215
+ density_maps = torch.stack([generate_density_map(label, image.shape[-2], image.shape[-1], sigma=self.sigma) for image, label in zip(images, labels)], 0)
216
+ image_names = [image_name] * len(images)
217
+ images = torch.stack(images, 0)
218
+
219
+ if self.return_filename:
220
+ return images, labels, density_maps, image_names
221
+ else:
222
+ return images, labels, density_maps
223
+
224
+
225
+ class NWPUTest(Dataset):
226
+ def __init__(
227
+ self,
228
+ transforms: Optional[Callable] = None,
229
+ return_filename: bool = False,
230
+ ) -> None:
231
+ """
232
+ The test set of NWPU-Crowd dataset. The test set is not labeled, so only images are returned.
233
+ """
234
+ self.root = os.path.join(curr_dir, "..", "data", "nwpu")
235
+ image_names = glob(os.path.join(self.root, "test", "images", "*.jpg"))
236
+
237
+ image_names = [os.path.basename(image_name) for image_name in image_names]
238
+ assert len(image_names) == 1500, f"NWPU test split should have 1500 images, but found {len(image_names)}."
239
+ image_names.sort(key=get_id)
240
+ self.image_names = tuple(image_names)
241
+
242
+ self.to_tensor = ToTensor()
243
+ self.normalize = Normalize(mean=mean, std=std)
244
+ self.transforms = transforms
245
+ self.return_filename = return_filename
246
+
247
+ def __len__(self) -> int:
248
+ return len(self.image_names)
249
+
250
+ def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, str]]:
251
+ image_name = self.image_names[idx]
252
+ image_path = os.path.join(self.root, "test", "images", image_name)
253
+
254
+ with open(image_path, "rb") as f:
255
+ # image = Image.open(f).convert("RGB")
256
+ image = jpeg_decoder.decode(f.read(), pixel_format=TJPF_RGB)
257
+ image = self.to_tensor(image)
258
+
259
+ label = torch.tensor([], dtype=torch.float) # dummy label
260
+ image, _ = self.transforms(image, label) if self.transforms is not None else (image, label)
261
+ image = self.normalize(image)
262
+
263
+ if self.return_filename:
264
+ return image, image_name
265
+ else:
266
+ return image
267
+
268
+
269
+ class ShanghaiTech(Dataset):
270
+ def __init__(
271
+ self,
272
+ split: str,
273
+ transforms: Optional[Callable] = None,
274
+ sigma: Optional[float] = None,
275
+ return_filename: bool = False,
276
+ num_crops: int = 1,
277
+ ) -> None:
278
+ super().__init__()
279
+ self.sha = Crowd(
280
+ dataset="sha",
281
+ split=split,
282
+ transforms=transforms,
283
+ sigma=sigma,
284
+ return_filename=return_filename,
285
+ num_crops=num_crops,
286
+ )
287
+ self.shb = Crowd(
288
+ dataset="shb",
289
+ split=split,
290
+ transforms=transforms,
291
+ sigma=sigma,
292
+ return_filename=return_filename,
293
+ num_crops=num_crops,
294
+ )
295
+ self.dataset = "sh"
296
+ self.split = split
297
+ self.transforms = transforms
298
+ self.sigma = sigma
299
+ self.return_filename = return_filename
300
+ self.num_crops = num_crops
301
+
302
+ def __len__(self) -> int:
303
+ return len(self.sha) + len(self.shb)
304
+
305
+ def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
306
+ if idx < len(self.sha):
307
+ return self.sha[idx]
308
+ else:
309
+ return self.shb[idx - len(self.sha)]
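
A minimal usage sketch for the `Crowd` dataset defined above, assuming the `data/<dataset>/<split>/{images,labels}` layout the class expects; the dataset name, crop size, and loader settings are illustrative only:

```python
from torch.utils.data import DataLoader
from datasets import Crowd, RandomApply, RandomCrop, RandomHorizontalFlip, collate_fn

# Paired transforms: each callable receives and returns (image, label).
transforms = RandomApply(
    (RandomCrop((448, 448)), RandomHorizontalFlip(p=0.5)),
    p=(1.0, 0.5),  # always crop, flip half of the time
)

train_set = Crowd(dataset="shb", split="train", transforms=transforms, sigma=None, num_crops=1)
loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn)

images, points, densities = next(iter(loader))
# images: (8, 3, 448, 448); points: list of 8 Nx2 tensors; densities: (8, 1, 448, 448)
```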
datasets/transforms.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from torchvision.transforms import ColorJitter as _ColorJitter
4
+ import torchvision.transforms.functional as TF
5
+ import numpy as np
6
+ from typing import Tuple, Union, Optional, Callable
7
+
8
+
9
+ def _crop(
10
+ image: Tensor,
11
+ label: Tensor,
12
+ top: int,
13
+ left: int,
14
+ height: int,
15
+ width: int,
16
+ ) -> Tuple[Tensor, Tensor]:
17
+ image = TF.crop(image, top, left, height, width)
18
+ if len(label) > 0:
19
+ label[:, 0] -= left
20
+ label[:, 1] -= top
21
+ label_mask = (label[:, 0] >= 0) & (label[:, 0] < width) & (label[:, 1] >= 0) & (label[:, 1] < height)
22
+ label = label[label_mask]
23
+
24
+ return image, label
25
+
26
+
27
+ def _resize(
28
+ image: Tensor,
29
+ label: Tensor,
30
+ height: int,
31
+ width: int,
32
+ ) -> Tuple[Tensor, Tensor]:
33
+ image_height, image_width = image.shape[-2:]
34
+ image = TF.resize(image, (height, width), interpolation=TF.InterpolationMode.BICUBIC, antialias=True) if (image_height != height or image_width != width) else image
35
+ if len(label) > 0 and (image_height != height or image_width != width):
36
+ label[:, 0] = label[:, 0] * width / image_width
37
+ label[:, 1] = label[:, 1] * height / image_height
38
+ label[:, 0] = label[:, 0].clamp(min=0, max=width - 1)
39
+ label[:, 1] = label[:, 1].clamp(min=0, max=height - 1)
40
+
41
+ return image, label
42
+
43
+
44
+ class RandomCrop(object):
45
+ def __init__(self, size: Tuple[int, int]) -> None:
46
+ self.size = size
47
+ assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
48
+
49
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
50
+ crop_height, crop_width = self.size
51
+ image_height, image_width = image.shape[-2:]
52
+ assert crop_height <= image_height and crop_width <= image_width, \
53
+ f"crop size should be no larger than image size, got crop size {self.size} and image size {image.shape}."
54
+
55
+ top = torch.randint(0, image_height - crop_height + 1, (1,)).item()
56
+ left = torch.randint(0, image_width - crop_width + 1, (1,)).item()
57
+ return _crop(image, label, top, left, crop_height, crop_width)
58
+
59
+
60
+ class Resize(object):
61
+ def __init__(self, size: Tuple[int, int]) -> None:
62
+ self.size = size
63
+ assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
64
+
65
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
66
+ return _resize(image, label, self.size[0], self.size[1])
67
+
68
+
69
+ class Resize2Multiple(object):
70
+ """
71
+ Resize the image so that it satisfies:
72
+ img_h = window_h + stride_h * n_h
73
+ img_w = window_w + stride_w * n_w
74
+ """
75
+ def __init__(
76
+ self,
77
+ window_size: Tuple[int, int],
78
+ stride: Tuple[int, int],
79
+ ) -> None:
80
+ window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
81
+ window_size = tuple(window_size)
82
+ stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
83
+ stride = tuple(stride)
84
+ assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
85
+ assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
86
+ assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
87
+ assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
88
+ assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."
89
+ self.window_size = window_size
90
+ self.stride = stride
91
+
92
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
93
+ image_height, image_width = image.shape[-2:]
94
+ window_height, window_width = self.window_size
95
+ stride_height, stride_width = self.stride
96
+ new_height = int(max(round((image_height - window_height) / stride_height), 0) * stride_height + window_height)
97
+ new_width = int(max(round((image_width - window_width) / stride_width), 0) * stride_width + window_width)
98
+
99
+ if new_height == image_height and new_width == image_width:
100
+ return image, label
101
+ else:
102
+ return _resize(image, label, new_height, new_width)
103
+
104
+
105
+ class ZeroPad2Multiple(object):
106
+ def __init__(
107
+ self,
108
+ window_size: Tuple[int, int],
109
+ stride: Tuple[int, int],
110
+ ) -> None:
111
+ window_size = (int(window_size), int(window_size)) if isinstance(window_size, (int, float)) else window_size
112
+ window_size = tuple(window_size)
113
+ stride = (int(stride), int(stride)) if isinstance(stride, (int, float)) else stride
114
+ stride = tuple(stride)
115
+ assert len(window_size) == 2, f"window_size should be a tuple (h, w), got {window_size}."
116
+ assert len(stride) == 2, f"stride should be a tuple (h, w), got {stride}."
117
+ assert all(s > 0 for s in window_size), f"window_size should be positive, got {window_size}."
118
+ assert all(s > 0 for s in stride), f"stride should be positive, got {stride}."
119
+ assert stride[0] <= window_size[0] and stride[1] <= window_size[1], f"stride should be no larger than window_size, got {stride} and {window_size}."
120
+ self.window_size = window_size
121
+ self.stride = stride
122
+
123
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
124
+ image_height, image_width = image.shape[-2:]
125
+ window_height, window_width = self.window_size
126
+ stride_height, stride_width = self.stride
127
+ new_height = int(max(np.ceil((image_height - window_height) / stride_height), 0) * stride_height + window_height)
128
+ new_width = int(max(np.ceil((image_width - window_width) / stride_width), 0) * stride_width + window_width)
129
+
130
+ if new_height == image_height and new_width == image_width:
131
+ return image, label
132
+ else:
133
+ assert new_height >= image_height and new_width >= image_width, f"new size should be no less than the original size, got {new_height} and {new_width}."
134
+ pad_height, pad_width = new_height - image_height, new_width - image_width
135
+ return TF.pad(image, (0, 0, pad_width, pad_height), fill=0), label # only pad the right and bottom sides so that the label coordinates are not affected
136
+
137
+
138
+ class RandomResizedCrop(object):
139
+ def __init__(
140
+ self,
141
+ size: Tuple[int, int],
142
+ scale: Tuple[float, float] = (0.75, 1.25),
143
+ ) -> None:
144
+ """
145
+ Randomly crop an image and resize it to a given size. The aspect ratio is preserved during this process.
146
+ """
147
+ self.size = size
148
+ self.scale = scale
149
+ assert len(self.size) == 2, f"size should be a tuple (h, w), got {self.size}."
150
+ assert 0 < self.scale[0] <= self.scale[1], f"scale should satisfy 0 < scale[0] <= scale[1], got {self.scale}."
151
+
152
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
153
+ out_height, out_width = self.size
154
+ # out_ratio = out_width / out_height
155
+
156
+ scale = torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() # if scale < 1, then the image will be zoomed in, otherwise zoomed out
157
+ in_height, in_width = image.shape[-2:]
158
+
159
+ # if in_width / in_height < out_ratio: # Image is too tall
160
+ # crop_width = int(in_width * scale)
161
+ # crop_height = int(crop_width / out_ratio)
162
+ # else: # Image is too wide
163
+ # crop_height = int(in_height * scale)
164
+ # crop_width = int(crop_height * out_ratio)
165
+
166
+ crop_height, crop_width = int(out_height * scale), int(out_width * scale)
167
+
168
+ if crop_height <= in_height and crop_width <= in_width: # directly crop and resize the image
169
+ top = torch.randint(0, in_height - crop_height + 1, (1,)).item()
170
+ left = torch.randint(0, in_width - crop_width + 1, (1,)).item()
171
+
172
+ else: # resize the image and then crop
173
+ ratio = max(crop_height / in_height, crop_width / in_width) # keep the aspect ratio
174
+ resize_height, resize_width = int(in_height * ratio) + 1, int(in_width * ratio) + 1 # add 1 to make sure the resized image is no less than the crop size
175
+ image, label = _resize(image, label, resize_height, resize_width)
176
+
177
+ top = torch.randint(0, resize_height - crop_height + 1, (1,)).item()
178
+ left = torch.randint(0, resize_width - crop_width + 1, (1,)).item()
179
+
180
+ image, label = _crop(image, label, top, left, crop_height, crop_width)
181
+ return _resize(image, label, out_height, out_width)
182
+
183
+
184
+ class RandomHorizontalFlip(object):
185
+ def __init__(self, p: float = 0.5) -> None:
186
+ self.p = p
187
+ assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."
188
+
189
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
190
+ if torch.rand(1) < self.p:
191
+ image = TF.hflip(image)
192
+
193
+ if len(label) > 0:
194
+ label[:, 0] = image.shape[-1] - 1 - label[:, 0] # if width is 256, then 0 -> 255, 1 -> 254, 2 -> 253, etc.
195
+ label[:, 0] = label[:, 0].clamp(min=0, max=image.shape[-1] - 1)
196
+
197
+ return image, label
198
+
199
+
200
+ class ColorJitter(object):
201
+ def __init__(
202
+ self,
203
+ brightness: Union[float, Tuple[float, float]] = 0.4,
204
+ contrast: Union[float, Tuple[float, float]] = 0.4,
205
+ saturation: Union[float, Tuple[float, float]] = 0.4,
206
+ hue: Union[float, Tuple[float, float]] = 0.2,
207
+ ) -> None:
208
+ self.color_jitter = _ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
209
+
210
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
211
+ return self.color_jitter(image), label
212
+
213
+
214
+ class RandomGrayscale(object):
215
+ def __init__(self, p: float = 0.1) -> None:
216
+ self.p = p
217
+ assert 0 <= self.p <= 1, f"p should be in range [0, 1], got {self.p}."
218
+
219
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
220
+ if torch.rand(1) < self.p:
221
+ image = TF.rgb_to_grayscale(image, num_output_channels=3)
222
+
223
+ return image, label
224
+
225
+
226
+ class GaussianBlur(object):
227
+ def __init__(self, kernel_size: int, sigma: Tuple[float, float] = (0.1, 2.0)) -> None:
228
+ self.kernel_size = kernel_size
229
+ self.sigma = sigma
230
+
231
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
232
+ return TF.gaussian_blur(image, self.kernel_size, self.sigma), label
233
+
234
+
235
+ class RandomApply(object):
236
+ def __init__(self, transforms: Tuple[Callable, ...], p: Union[float, Tuple[float, ...]] = 0.5) -> None:
237
+ self.transforms = transforms
238
+ p = [p] * len(transforms) if isinstance(p, float) else p
239
+ assert all(0 <= p_ <= 1 for p_ in p), f"p should be in range [0, 1], got {p}."
240
+ assert len(p) == len(transforms), f"p should be a float or a tuple of floats with the same length as transforms, got {p}."
241
+ self.p = p
242
+
243
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
244
+ for transform, p in zip(self.transforms, self.p):
245
+ if torch.rand(1) < p:
246
+ image, label = transform(image, label)
247
+
248
+ return image, label
249
+
250
+
251
+ class PepperSaltNoise(object):
252
+ def __init__(self, saltiness: float = 0.001, spiciness: float = 0.001) -> None:
253
+ self.saltiness = saltiness
254
+ self.spiciness = spiciness
255
+ assert 0 <= self.saltiness <= 1, f"saltiness should be in range [0, 1], got {self.saltiness}."
256
+ assert 0 <= self.spiciness <= 1, f"spiciness should be in range [0, 1], got {self.spiciness}."
257
+
258
+ def __call__(self, image: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
259
+ noise = torch.rand_like(image)
260
+ image = torch.where(noise < self.saltiness, 1., image) # Salt
261
+ image = torch.where(noise > 1 - self.spiciness, 0., image) # Pepper
262
+ return image, label
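
A minimal sketch of the paired `(image, label)` convention used by the transforms above: the label is an N x 2 tensor of (x, y) head coordinates, geometric transforms move the points together with the pixels, and photometric ones leave them untouched. Shapes and probabilities below are illustrative:

```python
import torch
from datasets.transforms import ColorJitter, GaussianBlur, PepperSaltNoise, RandomApply, RandomResizedCrop

image = torch.rand(3, 768, 1024)
label = torch.tensor([[100.0, 200.0], [512.0, 384.0]])  # (x, y) point annotations

photometric = RandomApply(
    (ColorJitter(), GaussianBlur(kernel_size=5), PepperSaltNoise()),
    p=(0.8, 0.2, 0.2),
)
geometric = RandomResizedCrop(size=(448, 448), scale=(0.75, 1.25))

image, label = photometric(image, label)  # pixels change, points do not
image, label = geometric(image, label)    # points are shifted/rescaled and filtered by the crop
print(image.shape, label.shape)           # torch.Size([3, 448, 448]) and the surviving points
```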
datasets/utils.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+ from torch import Tensor
+ from scipy.ndimage import gaussian_filter
+ from typing import Optional, List, Tuple
+
+
+ def get_id(x: str) -> int:
+     return int(x.split(".")[0])
+
+
+ def generate_density_map(label: Tensor, height: int, width: int, sigma: Optional[float] = None) -> Tensor:
+     """
+     Generate the density map based on the dot annotations provided by the label.
+     """
+     density_map = torch.zeros((1, height, width), dtype=torch.float32)
+
+     if len(label) > 0:
+         assert len(label.shape) == 2 and label.shape[1] == 2, f"label should be a Nx2 tensor, got {label.shape}."
+         label_ = label.long()
+         label_[:, 0] = label_[:, 0].clamp(min=0, max=width - 1)
+         label_[:, 1] = label_[:, 1].clamp(min=0, max=height - 1)
+         density_map[0, label_[:, 1], label_[:, 0]] = 1.0
+
+     if sigma is not None:
+         assert sigma > 0, f"sigma should be positive if not None, got {sigma}."
+         density_map = torch.from_numpy(gaussian_filter(density_map, sigma=sigma))
+
+     return density_map
+
+
+ def collate_fn(batch: List[Tensor]) -> Tuple[Tensor, List[Tensor], Tensor]:
+     batch = list(zip(*batch))
+     images = batch[0]
+     assert len(images[0].shape) == 4, f"images should be a 4D tensor, got {images[0].shape}."
+     if len(batch) == 4: # image, label, density_map, image_name
+         images = torch.cat(images, 0)
+         points = batch[1] # list of lists of tensors, flatten it
+         points = [p for points_ in points for p in points_]
+         densities = torch.cat(batch[2], 0)
+         image_names = batch[3] # list of lists of strings, flatten it
+         image_names = [name for names_ in image_names for name in names_]
+
+         return images, points, densities, image_names
+
+     elif len(batch) == 3: # image, label, density_map
+         images = torch.cat(images, 0)
+         points = batch[1]
+         points = [p for points_ in points for p in points_]
+         densities = torch.cat(batch[2], 0)
+
+         return images, points, densities
+
+     elif len(batch) == 2: # image, image_name. NWPU test dataset
+         images = torch.cat(images, 0)
+         image_names = batch[1]
+         image_names = [name for names_ in image_names for name in names_]
+
+         return images, image_names
+
+     else:
+         images = torch.cat(images, 0)
+
+         return images
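
A minimal sketch of `generate_density_map` above: each annotated point marks one pixel with 1.0 (points that round to the same pixel collapse into a single 1), and the optional Gaussian filter spreads that mass while approximately preserving the total. The coordinates below are illustrative:

```python
import torch
from datasets.utils import generate_density_map

points = torch.tensor([[10.0, 20.0], [10.4, 20.2], [300.0, 150.0]])  # (x, y) coordinates
dm_hard = generate_density_map(points, height=256, width=512, sigma=None)
dm_soft = generate_density_map(points, height=256, width=512, sigma=4.0)

print(dm_hard.shape)          # torch.Size([1, 256, 512])
print(dm_hard.sum().item())   # 2.0: the first two points fall into the same pixel
print(dm_soft.sum().item())   # ~2.0: smoothing redistributes mass, it does not add any
```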
efficiency.py ADDED
@@ -0,0 +1,163 @@
1
+ from argparse import ArgumentParser
2
+ import time
3
+ import os
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from contextlib import nullcontext
7
+ import json
8
+ from models import get_model
9
+
10
+
11
+ parser = ArgumentParser(description="Train an EBC model.")
12
+ parser.add_argument("--model_info_path", type=str, required=True, help="Path to the model information file.")
13
+
14
+ parser.add_argument("--batch_size", type=int, default=1, help="Batch size for the model.")
15
+ parser.add_argument("--height", type=int, default=768, help="Height of the input image.")
16
+ parser.add_argument("--width", type=int, default=1024, help="Width of the input image.")
17
+
18
+ parser.add_argument("--num_iterations", type=int, default=200, help="Number of iterations to run the model.")
19
+ parser.add_argument("--num_warmup", type=int, default=20, help="Dispose of the first N iterations.")
20
+
21
+ parser.add_argument("--device", type=str, choices=["cpu", "cuda", "mps"], help="Device to run the model on. Options are 'cpu', 'cuda', or 'mps'.")
22
+ parser.add_argument("--amp", action="store_true", help="Enable autocast mixed precision (fp16/bf16).")
23
+ parser.add_argument("--half", action="store_true", help="Use half precision for the model.")
24
+ parser.add_argument("--channels_last", action="store_true", help="Use NHWC memory format (recommended for CUDA).")
25
+ parser.add_argument("--compile", action="store_true", help="Enable torch.compile if available.")
26
+ parser.add_argument("--threads", type=int, default=None, help="torch.set_num_threads(threads) for CPU")
27
+ parser.add_argument("--sleep_time", type=float, default=0.0, help="Seconds to sleep after *each* iteration (cool-down).")
28
+
29
+ _normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30
+
31
+
32
+ def _dummy_input(bs, h, w, device, half, channels_last):
33
+ x = torch.rand(bs, 3, h, w, device=device)
34
+ x = _normalize(x)
35
+ if half:
36
+ x = x.half()
37
+ if channels_last:
38
+ x = x.to(memory_format=torch.channels_last)
39
+ return x
40
+
41
+
42
+ def _maybe_sync(dev):
43
+ if dev.type == "cuda":
44
+ torch.cuda.synchronize()
45
+
46
+
47
+ @torch.inference_mode()
48
+ def benchmark(
49
+ model: torch.nn.Module,
50
+ inp: torch.Tensor,
51
+ warmup: int,
52
+ steps: int,
53
+ amp: bool,
54
+ sleep_time: float = 0.0
55
+ ):
56
+ cm = torch.autocast(device_type=inp.device.type) if amp else nullcontext()
57
+
58
+ # --- warm-up ---
59
+ for _ in range(warmup):
60
+ with cm:
61
+ _ = model(inp)
62
+ _maybe_sync(inp.device)
63
+
64
+ # --- timed loop ---
65
+ total_time = 0.0
66
+ for _ in range(steps):
67
+ tic = time.perf_counter()
68
+ with cm:
69
+ _ = model(inp)
70
+
71
+ toc = time.perf_counter()
72
+ total_time += toc - tic
73
+
74
+ if sleep_time > 0:
75
+ time.sleep(sleep_time)
76
+
77
+ _maybe_sync(inp.device)
78
+
79
+ fps = steps / total_time
80
+ return fps, total_time / steps
81
+
82
+
83
+ def main(args):
84
+ assert os.path.isfile(args.model_info_path), \
85
+ f"{args.model_info_path} not found"
86
+
87
+ model = get_model(model_info_path=args.model_info_path)
88
+ model.eval()
89
+
90
+ if args.channels_last:
91
+ model = model.to(memory_format=torch.channels_last)
92
+ if args.half:
93
+ model = model.half()
94
+
95
+ device = torch.device(args.device)
96
+ model = model.to(device)
97
+
98
+ if args.compile and hasattr(torch, "compile"):
99
+ model = torch.compile(model, mode="reduce-overhead")
100
+
101
+ if args.threads:
102
+ torch.set_num_threads(args.threads)
103
+ torch.set_num_interop_threads(1)
104
+
105
+ inp = _dummy_input(
106
+ args.batch_size,
107
+ args.height,
108
+ args.width,
109
+ device,
110
+ args.half,
111
+ args.channels_last
112
+ )
113
+
114
+ fps, t_avg = benchmark(
115
+ model,
116
+ inp,
117
+ warmup=args.num_warmup,
118
+ steps=args.num_iterations,
119
+ amp=args.amp,
120
+ sleep_time=args.sleep_time
121
+ )
122
+
123
+ cfg = vars(args)
124
+ cfg.pop("model_info_path")
125
+ print(json.dumps(cfg, indent=2))
126
+ print(f"\nAverage latency: {t_avg*1000:6.2f} ms | FPS: {fps:,.2f}")
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main(parser.parse_args())
131
+
132
+
133
+ # CUDA @FP16 + channels_last + torch.compile
134
+ # python efficiency.py \
135
+ # --model_info_path checkpoints/shb/ebc_p/best_mae.pth \
136
+ # --device cuda --half --amp --channels_last --compile
137
+
138
+ # CUDA @AMP + channels_last + torch.compile
139
+ # python efficiency.py \
140
+ # --model_info_path checkpoints/shb/ebc_p/best_mae.pth \
141
+ # --device cuda --amp --channels_last --compile
142
+
143
+ # CUDA @FP32 + channels_last + torch.compile
144
+ # python efficiency.py \
145
+ # --model_info_path checkpoints/shb/ebc_p/best_mae.pth \
146
+ # --device cuda --channels_last --compile
147
+
148
+ # AMD 5900X (12 Core) + channels_last + torch.compile
149
+ # export OMP_NUM_THREADS=12; export MKL_NUM_THREADS=12
150
+ # python efficiency.py \
151
+ # --model_info_path checkpoints/shb/ebc_p/best_mae.pth \
152
+ # --device cpu --threads 12 --channels_last --compile
153
+
154
+ # Apple M1 Pro (6 Performance Cores). Compiling makes it slower.
155
+ # export OMP_NUM_THREADS=6; export VECLIB_MAXIMUM_THREADS=6
156
+ # python efficiency.py \
157
+ # --model_info_path checkpoints/shb/ebc_p/best_mae.pth \
158
+ # --device cpu --threads 6
159
+
160
+ # Apple M1 Pro MPS @FP32 + torch.compile
161
+ # python efficiency.py \
162
+ # --model_info_path checkpoints/shb/ebc_p/best_mae.pth \
163
+ # --device mps --channels_last --compile
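
For quick programmatic checks, `benchmark()` can also be driven outside the CLI. A minimal sketch, assuming the repository root is on `PYTHONPATH` and using a toy convolution in place of a real EBC checkpoint (the toy model and sizes below are placeholders, not part of the repo):

```python
import torch
from torch import nn

from efficiency import benchmark, _dummy_input  # repository root on PYTHONPATH

# Toy stand-in model; swap in get_model(model_info_path=...) for a real checkpoint.
toy_model = nn.Conv2d(3, 1, kernel_size=3, padding=1).eval()

device = torch.device("cpu")
inp = _dummy_input(bs=1, h=256, w=256, device=device, half=False, channels_last=False)

fps, latency = benchmark(toy_model, inp, warmup=2, steps=10, amp=False)
print(f"~{latency * 1000:.2f} ms/iter, ~{fps:.1f} FPS (toy model, CPU)")
```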
evaluate.py ADDED
@@ -0,0 +1,84 @@
1
+ import torch
2
+ from torch.amp import autocast
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+ from torch import nn, Tensor
6
+ from torch.utils.data import DataLoader
7
+ from typing import Tuple, Optional
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+
11
+ from utils import sliding_window_predict, barrier, calculate_errors
12
+
13
+
14
+ def evaluate(
15
+ model: nn.Module,
16
+ data_loader: DataLoader,
17
+ sliding_window: bool,
18
+ max_input_size: int = 4096,
19
+ window_size: int = 224,
20
+ stride: int = 224,
21
+ max_num_windows: int = 64,
22
+ device: torch.device = torch.device("cuda"),
23
+ amp: bool = False,
24
+ local_rank: int = 0,
25
+ nprocs: int = 1,
26
+ progress_bar: bool = True,
27
+ ) -> Tuple[Tensor, Tensor]:
28
+ ddp = nprocs > 1
29
+ model = model.to(device)
30
+ model.eval()
31
+ pred_counts, gt_counts = [], []
32
+ data_iter = tqdm(data_loader) if (local_rank == 0 and progress_bar) else data_loader
33
+
34
+ for image, gt_points, _ in data_iter:
35
+ image = image.to(device)
36
+ image_height, image_width = image.shape[-2:]
37
+ gt_counts.extend([len(p) for p in gt_points])
38
+
39
+ # Resize image if it's smaller than the window size
40
+ aspect_ratio = image_width / image_height
41
+ if image_height < window_size:
42
+ new_height = window_size
43
+ new_width = int(new_height * aspect_ratio)
44
+ image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
45
+ image_height, image_width = new_height, new_width
46
+ if image_width < window_size:
47
+ new_width = window_size
48
+ new_height = int(new_width / aspect_ratio)
49
+ image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
50
+ image_height, image_width = new_height, new_width
51
+
52
+ with torch.set_grad_enabled(False), autocast(device_type="cuda", enabled=amp):
53
+ if sliding_window or (image_height * image_width) > max_input_size ** 2:
54
+ pred_den_maps = sliding_window_predict(model, image, window_size, stride, max_num_windows)
55
+ else:
56
+ pred_den_maps = model(image)
57
+
58
+ pred_counts.extend(pred_den_maps.sum(dim=(-1, -2, -3)).cpu().numpy().tolist())
59
+
60
+ barrier(ddp)
61
+ assert len(pred_counts) == len(gt_counts), f"Length of predictions and ground truths should be equal, but got {len(pred_counts)} and {len(gt_counts)}"
62
+
63
+ if ddp:
64
+ pred_counts, gt_counts = torch.tensor(pred_counts, device=device), torch.tensor(gt_counts, device=device)
65
+ # Pad `pred_counts` and `gt_counts` to the same length across all processes.
66
+ local_length = torch.tensor([len(pred_counts)], device=device)
67
+ lengths = [torch.zeros_like(local_length) for _ in range(nprocs)]
68
+ dist.all_gather(lengths, local_length)
69
+ max_length = max([l.item() for l in lengths])
70
+ padded_pred_counts, padded_gt_counts = torch.full((max_length,), float("nan"), device=device), torch.full((max_length,), float("nan"), device=device)
71
+ padded_pred_counts[:len(pred_counts)], padded_gt_counts[:len(gt_counts)] = pred_counts, gt_counts
72
+ gathered_pred_counts, gathered_gt_counts = [torch.zeros_like(padded_pred_counts) for _ in range(nprocs)], [torch.zeros_like(padded_gt_counts) for _ in range(nprocs)]
73
+ dist.all_gather(gathered_pred_counts, padded_pred_counts)
74
+ dist.all_gather(gathered_gt_counts, padded_gt_counts)
75
+ # Concatenate predictions and ground truths from all processes and remove padding (nan values).
76
+ pred_counts, gt_counts = torch.cat(gathered_pred_counts).cpu(), torch.cat(gathered_gt_counts).cpu()
77
+ pred_counts, gt_counts = pred_counts[~torch.isnan(pred_counts)], gt_counts[~torch.isnan(gt_counts)]
78
+ pred_counts, gt_counts = pred_counts.numpy(), gt_counts.numpy()
79
+
80
+ else:
81
+ pred_counts, gt_counts = np.array(pred_counts), np.array(gt_counts)
82
+
83
+ torch.cuda.empty_cache()
84
+ return calculate_errors(pred_counts, gt_counts)
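
`sliding_window_predict`, `barrier` and `calculate_errors` are imported from the top-level `utils` package added in this commit. For orientation, a minimal sketch of the kind of count metrics a `calculate_errors`-style function computes; the actual keys and metrics in `utils` may differ:

```python
import numpy as np

def count_errors_sketch(pred_counts: np.ndarray, gt_counts: np.ndarray) -> dict:
    """Typical crowd-counting count metrics: MAE, RMSE, and normalized absolute error."""
    err = pred_counts - gt_counts
    return {
        "mae": float(np.abs(err).mean()),
        "rmse": float(np.sqrt((err ** 2).mean())),
        "nae": float((np.abs(err) / np.maximum(gt_counts, 1)).mean()),
    }

# Three images: predicted counts vs. ground-truth counts.
print(count_errors_sketch(np.array([102.0, 48.5, 7.2]), np.array([100.0, 50.0, 8.0])))
```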
losses/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .loss import QuadLoss
2
+ from .bregman_pytorch import sinkhorn
3
+
4
+ __all__ = [
5
+ "QuadLoss",
6
+ "sinkhorn",
7
+ ]
losses/bregman_pytorch.py ADDED
@@ -0,0 +1,70 @@
1
+ # Code modified from https://github.com/cvlab-stonybrook/DM-Count/blob/master/losses/bregman_pytorch.py
2
+ import torch
3
+ from torch.amp import autocast
4
+ from torch import Tensor
5
+ from typing import Union, Tuple, Dict
6
+
7
+ M_EPS = 1e-16
8
+
9
+
10
+ @torch.no_grad()
11
+ @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
12
+ def sinkhorn(
13
+ a: Tensor,
14
+ b: Tensor,
15
+ C: Tensor,
16
+ reg: float = 1e-1,
17
+ maxIter: int = 1000,
18
+ stopThr: float = 1e-9,
19
+ verbose: bool = False,
20
+ log: bool = True,
21
+ eval_freq: int = 10,
22
+ print_freq: int = 200,
23
+ ) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]:
24
+ device = a.device
25
+ na, nb = C.shape
26
+ assert na == a.shape[0] and nb == b.shape[0], f"Shapes of a ({a.shape}) or b ({b.shape}) do not match that of C ({C.shape})"
27
+ assert reg > 0, f"reg should be greater than 0. Found reg = {reg}"
28
+ assert a.min() >= 0. and b.min() >= 0., f"Elements in a and b should be nonnegative. Found a.min() = {a.min()}, b.min() = {b.min()}"
29
+
30
+ if log:
31
+ log = {"err": []}
32
+
33
+ u = torch.ones(na, dtype=a.dtype, device=device) / na
34
+ v = torch.ones(nb, dtype=b.dtype, device=device) / nb
35
+ K = torch.exp(-C / reg)
36
+
37
+ it, err = 1, 1
38
+ while (err > stopThr and it <= maxIter):
39
+ u_pre, v_pre = u.clone(), v.clone()
40
+ KTu = torch.matmul(K.T, u)
41
+ v = b / (KTu + M_EPS)
42
+ Kv = torch.matmul(K, v)
43
+ u = a / (Kv + M_EPS)
44
+
45
+ if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
46
+ print("Warning: numerical errors at iteration", it)
47
+ u, v = u_pre, v_pre
48
+ break
49
+
50
+ if log and it % eval_freq == 0:
51
+ b_hat = torch.matmul(u, K) * v
52
+ err = (b - b_hat).pow(2).sum().item()
53
+ log["err"].append(err)
54
+
55
+ if verbose and it % print_freq == 0:
56
+ print(f"Iteration {it}, constraint error {err}")
57
+
58
+ it += 1
59
+
60
+ if log:
61
+ log["u"] = u
62
+ log["v"] = v
63
+ log["alpha"] = reg * torch.log(u + M_EPS)
64
+ log["beta"] = reg * torch.log(v + M_EPS)
65
+
66
+ P = u.view(-1, 1) * K * v.view(1, -1)
67
+ if log:
68
+ return P, log
69
+ else:
70
+ return P
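
A minimal sketch of exercising `sinkhorn` on a toy transport problem, assuming the repository root is on `PYTHONPATH` (on a CPU-only machine the `cuda` autocast decorator typically just warns and disables itself):

```python
import torch
from losses.bregman_pytorch import sinkhorn  # repository root on PYTHONPATH

# Toy problem: move mass from 3 sources to 4 targets under a random nonnegative cost.
a = torch.tensor([0.5, 0.3, 0.2])           # source distribution (sums to 1)
b = torch.tensor([0.25, 0.25, 0.25, 0.25])  # target distribution (sums to 1)
C = torch.rand(3, 4)                        # cost matrix of shape [len(a), len(b)]

P, log = sinkhorn(a, b, C, reg=0.1, maxIter=500, log=True)
print(P.shape)            # transport plan, torch.Size([3, 4])
print(P.sum().item())     # total transported mass, ~1.0
print(log["beta"].shape)  # dual variable consumed by the OT loss, torch.Size([4])
```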
losses/dm_loss.py ADDED
@@ -0,0 +1,142 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from torch.amp import autocast
4
+ from typing import List, Tuple, Dict
5
+
6
+ from .bregman_pytorch import sinkhorn
7
+ from .utils import _reshape_density
8
+
9
+ EPS = 1e-8
10
+
11
+
12
+ class OTLoss(nn.Module):
13
+ def __init__(
14
+ self,
15
+ input_size: int,
16
+ block_size: int,
17
+ numItermax: int = 100,
18
+ regularization: float = 10.0
19
+ ) -> None:
20
+ super().__init__()
21
+ assert input_size % block_size == 0
22
+
23
+ self.input_size = input_size
24
+ self.block_size = block_size
25
+ self.num_blocks_h = input_size // block_size
26
+ self.num_blocks_w = input_size // block_size
27
+ self.numItermax = numItermax
28
+ self.regularization = regularization
29
+
30
+ # coordinate is same to image space, set to constant since crop size is same
31
+ self.coords_h = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
32
+ self.coords_w = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
33
+ self.coords_h, self.coords_w = self.coords_h.unsqueeze(0), self.coords_w.unsqueeze(0) # [1, #coordinates]
34
+
35
+ def set_numItermax(self, numItermax: int) -> None:
36
+ self.numItermax = numItermax
37
+
38
+ @autocast(device_type="cuda", enabled=True, dtype=torch.float32) # avoid numerical instability
39
+ def forward(self, pred_den_map: Tensor, pred_den_map_normed: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
40
+ assert pred_den_map.shape[1:] == pred_den_map_normed.shape[1:] == (1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map to have shape (B, 1, {self.num_blocks_h}, {self.num_blocks_w}), but got {pred_den_map.shape} and {pred_den_map_normed.shape}"
41
+ assert len(gt_points) == pred_den_map.shape[0] == pred_den_map_normed.shape[0], f"Expected gt_points to have length {pred_den_map_normed.shape[0]}, but got {len(gt_points)}"
42
+ device = pred_den_map.device
43
+
44
+ loss = torch.zeros(1, device=device)
45
+ ot_obj_values = torch.zeros(1, device=device)
46
+ w_dist = torch.zeros(1, device=device) # Wasserstein distance
47
+ coords_h, coords_w = self.coords_h.to(device), self.coords_w.to(device) # [1, #coordinates]
48
+ for idx, points in enumerate(gt_points):
49
+ if len(points) > 0:
50
+ # compute l2 square distance, it should be source target distance. [#gt, #coordinates * #coordinates]
51
+ x, y = points[:, 0].unsqueeze(1), points[:, 1].unsqueeze(1) # [#gt, 1]
52
+ x_dist = -2 * torch.matmul(x, coords_w) + x * x + coords_w * coords_w # [#gt, #coordinates]
53
+ y_dist = -2 * torch.matmul(y, coords_h) + y * y + coords_h * coords_h # [#gt, #coordinates]
54
+ dist = x_dist.unsqueeze(1) + y_dist.unsqueeze(2)
55
+ dist = dist.view((dist.shape[0], -1)) # size of [#gt, #coordinates * #coordinates]
56
+
57
+ source_prob = pred_den_map_normed[idx].view(-1).detach()
58
+ target_prob = (torch.ones(len(points)) / len(points)).to(device)
59
+ # use sinkhorn to solve OT, compute optimal beta.
60
+ P, log = sinkhorn(
61
+ a=target_prob,
62
+ b=source_prob,
63
+ C=dist,
64
+ reg=self.regularization,
65
+ maxIter=self.numItermax,
66
+ log=True
67
+ )
68
+ beta = log["beta"] # size is the same as source_prob: [#coordinates * #coordinates]
69
+ w_dist += (dist * P).sum()
70
+ ot_obj_values += (pred_den_map_normed[idx] * beta.view(1, self.num_blocks_h, self.num_blocks_w)).sum()
71
+ # compute the gradient of OT loss to predicted density (pred_den_map).
72
+ # im_grad = beta / source_count - < beta, source_density> / (source_count)^2
73
+ source_density = pred_den_map[idx].view(-1).detach()
74
+ source_count = source_density.sum()
75
+ gradient_1 = source_count / (source_count * source_count + EPS) * beta # size of [#coordinates * #coordinates]
76
+ gradient_2 = (source_density * beta).sum() / (source_count * source_count + EPS) # size of 1
77
+ gradient = gradient_1 - gradient_2
78
+ gradient = gradient.detach().view(1, self.num_blocks_h, self.num_blocks_w)
79
+ # Define loss = <im_grad, predicted density>. The gradient of loss w.r.t predicted density is im_grad.
80
+ loss += torch.sum(pred_den_map[idx] * gradient)
81
+
82
+ return loss, w_dist, ot_obj_values
83
+
84
+
85
+ class DMLoss(nn.Module):
86
+ def __init__(
87
+ self,
88
+ input_size: int,
89
+ block_size: int,
90
+ numItermax: int = 100,
91
+ regularization: float = 10.0,
92
+ weight_ot: float = 0.1,
93
+ weight_tv: float = 0.01,
94
+ weight_cnt: float = 1.0,
95
+ ) -> None:
96
+ super().__init__()
97
+ self.input_size = input_size
98
+ self.block_size = block_size
99
+ self.weight_ot = weight_ot
100
+ self.weight_tv = weight_tv
101
+ self.weight_cnt = weight_cnt
102
+
103
+ self.ot_loss = OTLoss(
104
+ input_size=self.input_size,
105
+ block_size=self.block_size,
106
+ numItermax=numItermax,
107
+ regularization=regularization,
108
+ )
109
+ self.tv_loss = nn.L1Loss(reduction="none")
110
+ self.cnt_loss = nn.L1Loss(reduction="mean")
111
+ self.weight_ot = weight_ot
112
+ self.weight_tv = weight_tv
113
+
114
+ @autocast(device_type="cuda", enabled=True, dtype=torch.float32) # avoid numerical instability
115
+ def forward(self, pred_den_map: Tensor, gt_den_map: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
116
+ gt_den_map = _reshape_density(gt_den_map, block_size=self.ot_loss.block_size) if gt_den_map.shape[-2:] != pred_den_map.shape[-2:] else gt_den_map
117
+ assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"
118
+
119
+ pred_cnt = pred_den_map.view(pred_den_map.shape[0], -1).sum(dim=1)
120
+ pred_den_map_normed = pred_den_map / (pred_cnt.view(-1, 1, 1, 1) + EPS)
121
+ gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32).to(pred_den_map.device)
122
+ gt_den_map_normed = gt_den_map / (gt_cnt.view(-1, 1, 1, 1) + EPS)
123
+
124
+ ot_loss, w_dist, _ = self.ot_loss(pred_den_map, pred_den_map_normed, gt_points)
125
+
126
+ tv_loss = (self.tv_loss(pred_den_map_normed, gt_den_map_normed).sum(dim=(1, 2, 3)) * gt_cnt).mean() if self.weight_tv > 0 else 0
127
+
128
+ cnt_loss = self.cnt_loss(pred_cnt, gt_cnt) if self.weight_cnt > 0 else 0
129
+
130
+ loss = ot_loss * self.weight_ot + tv_loss * self.weight_tv + cnt_loss * self.weight_cnt
131
+
132
+ loss_info = {
133
+ "ot_loss": ot_loss.detach(),
134
+ "dm_loss": loss.detach(),
135
+ "w_dist": w_dist.detach(),
136
+ }
137
+ if self.weight_tv > 0:
138
+ loss_info["tv_loss"] = tv_loss.detach()
139
+ if self.weight_cnt > 0:
140
+ loss_info["cnt_loss"] = cnt_loss.detach()
141
+
142
+ return loss, loss_info
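
A minimal sketch of driving `DMLoss` with random inputs; `input_size=256` and `block_size=8` are illustrative values, not the training configuration:

```python
import torch
from losses.dm_loss import DMLoss  # repository root on PYTHONPATH

input_size, block_size = 256, 8     # 256 / 8 = 32 blocks per side
criterion = DMLoss(input_size=input_size, block_size=block_size)

B = 2
pred_den_map = torch.rand(B, 1, 32, 32)                        # blockwise predicted density
gt_den_map = torch.rand(B, 1, input_size, input_size) * 1e-3   # pixel-level GT, folded to blocks internally
gt_points = [torch.rand(10, 2) * input_size,                   # (x, y) head locations per image
             torch.rand(5, 2) * input_size]

loss, info = criterion(pred_den_map, gt_den_map, gt_points)
print(loss.item(), sorted(info.keys()))
```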
losses/dual_loss.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from typing import List, Tuple, Dict
5
+
6
+ from .dm_loss import DMLoss
7
+ from .multiscale_mae import MultiscaleMAE
8
+ from .utils import _reshape_density
9
+
10
+
11
+
12
+ class DualLoss(nn.Module):
13
+ def __init__(
14
+ self,
15
+ input_size: int,
16
+ block_size: int,
17
+ bins: List[Tuple[float, float]],
18
+ bin_centers: List[float],
19
+ cls_loss: str = "ce",
20
+ reg_loss: str = "dm",
21
+ weight_tv: float = 0.01,
22
+ weight_cls: float = 0.1,
23
+ weight_reg: float = 0.1,
24
+ numItermax: int = 100,
25
+ regularization: float = 10.0,
26
+ scales: List[int] = [1, 2, 4],
27
+ min_scale_weight: float = 0.25,
28
+ max_scale_weight: float = 0.75,
29
+ alpha: float = 0.5,
30
+ ) -> None:
31
+ super().__init__()
32
+ assert len(bins) == len(bin_centers) >= 2, f"Expected bins and bin_centers to have at least 2 elements, got {len(bins)} and {len(bin_centers)}"
33
+ assert all([len(b) == 2 for b in bins]), f"Expected all bins to be of length 2, got {bins}"
34
+ assert all(b[0] <= p <= b[1] for b, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"
35
+ assert cls_loss in ["ce", "mae", "mse", "none"], f"Expected cls_loss to be one of ['ce', 'mae', 'mse', 'none'], got {cls_loss}"
36
+ assert reg_loss in ["dm", "msmae", "mae", "mse", "none"], f"Expected reg_loss to be one of ['dm', 'msmae', 'mae', 'mse', 'none'], got {reg_loss}"
37
+ assert not (cls_loss == "none" and reg_loss == "none"), "Expected at least one of cls_loss and reg_loss to be provided"
38
+ assert weight_cls is None or weight_cls >= 0, f"Expected weight_cls to be non-negative, got {weight_cls}"
39
+ assert weight_reg is None or weight_reg >= 0, f"Expected weight_reg to be non-negative, got {weight_reg}"
40
+ assert weight_tv is None or weight_tv >= 0, f"Expected weight_tv to be non-negative, got {weight_tv}"
41
+ assert min_scale_weight is None or max_scale_weight is None or max_scale_weight >= min_scale_weight > 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
42
+ assert alpha is None or 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
43
+
44
+ if reg_loss == "dm":
45
+ assert numItermax is not None and numItermax > 0, f"Expected numItermax to be a positive integer, got {numItermax}"
46
+ assert regularization is not None and regularization > 0, f"Expected regularization to be a positive float, got {regularization}"
47
+ assert weight_tv is not None and weight_tv >= 0, f"Expected weight_tv to be non-negative, got {weight_tv}"
48
+ else:
49
+ weight_tv, numItermax, regularization = None, None, None
50
+
51
+ if reg_loss == "msmae":
52
+ assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
53
+ assert max_scale_weight >= min_scale_weight > 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
54
+ assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
55
+ else:
56
+ scales = None
57
+ min_scale_weight, max_scale_weight = None, None
58
+ alpha = None
59
+
60
+ weight_cls = weight_cls if weight_cls is not None else 0
61
+ weight_reg = weight_reg if weight_reg is not None else 0
62
+
63
+ self.input_size, self.block_size = input_size, block_size
64
+ self.num_blocks_h, self.num_blocks_w = input_size // block_size, input_size // block_size
65
+ self.bins, self.bin_centers, self.num_bins = bins, bin_centers, len(bins)
66
+ self.cls_loss, self.reg_loss = cls_loss, reg_loss
67
+ self.weight_cls, self.weight_reg = weight_cls, weight_reg
68
+ self.numItermax, self.regularization = numItermax, regularization
69
+ self.weight_tv = weight_tv
70
+ self.scales = scales
71
+ self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight
72
+
73
+ if cls_loss == "ce":
74
+ self.cls_loss_fn = nn.CrossEntropyLoss(reduction="none")
75
+ self.weight_cls = 1.0
76
+ elif cls_loss == "mae":
77
+ self.cls_loss_fn = nn.L1Loss(reduction="none")
78
+ self.weight_cls = weight_cls
79
+ elif cls_loss == "mse":
80
+ self.cls_loss_fn = nn.MSELoss(reduction="none")
81
+ self.weight_cls = weight_cls
82
+ else: # cls_loss == "none"
83
+ self.cls_loss_fn = None
84
+ self.weight_cls = 0
85
+
86
+ if reg_loss == "dm":
87
+ self.reg_loss_fn = DMLoss(
88
+ input_size=input_size,
89
+ block_size=block_size,
90
+ numItermax=numItermax,
91
+ regularization=regularization,
92
+ weight_ot=weight_reg,
93
+ weight_tv=weight_tv,
94
+ weight_cnt=0, # Calculate the count loss separately
95
+ )
96
+ self.weight_reg = 1.0
97
+ elif reg_loss == "msmae":
98
+ self.reg_loss_fn = MultiscaleMAE(scales=scales, weights=None, min_scale_weight=min_scale_weight, max_scale_weight=max_scale_weight, alpha=alpha)
99
+ self.weight_reg = 1.0
100
+ elif reg_loss == "mae":
101
+ self.reg_loss_fn = nn.L1Loss(reduction="none")
102
+ self.weight_reg = weight_reg
103
+ elif reg_loss == "mse":
104
+ self.reg_loss_fn = nn.MSELoss(reduction="none")
105
+ self.weight_reg = weight_reg
106
+ else:
107
+ self.reg_loss_fn = None
108
+ self.weight_reg = 0
109
+
110
+ self.cnt_loss_fn = nn.L1Loss(reduction="none")
111
+
112
+ def _bin_count(self, density_map: Tensor) -> Tensor:
113
+ class_map = torch.zeros_like(density_map, dtype=torch.long)
114
+ for idx, (low, high) in enumerate(self.bins):
115
+ mask = (density_map >= low) & (density_map <= high)
116
+ class_map[mask] = idx
117
+ return class_map.squeeze(1) # remove channel dimension
118
+
119
+ def forward(
120
+ self,
121
+ pred_logit_map: Tensor,
122
+ pred_den_map: Tensor,
123
+ gt_den_map: Tensor,
124
+ gt_points: List[Tensor]
125
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
126
+ B = pred_logit_map.shape[0]
127
+ assert pred_logit_map.shape == (B, self.num_bins, self.num_blocks_h, self.num_blocks_w), f"Expected pred_logit_map to have shape {B, self.num_bins, self.num_blocks_h, self.num_blocks_w}, got {pred_logit_map.shape}"
128
+ if gt_den_map.shape[-2:] != (self.num_blocks_h, self.num_blocks_w):
129
+ assert gt_den_map.shape[-2:] == (self.input_size, self.input_size), f"Expected gt_den_map to have shape {B, 1, self.input_size, self.input_size}, got {gt_den_map.shape}"
130
+ gt_den_map = _reshape_density(gt_den_map, block_size=self.block_size)
131
+ assert pred_den_map.shape == gt_den_map.shape == (B, 1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map and gt_den_map to have shape (B, 1, H, W), got {pred_den_map.shape} and {gt_den_map.shape}"
132
+ assert len(gt_points) == B, f"Expected gt_points to have length B, got {len(gt_points)}"
133
+
134
+ loss_info = {}
135
+
136
+ if self.weight_cls > 0:
137
+ gt_class_map = self._bin_count(gt_den_map)
138
+ if self.cls_loss == "ce":
139
+ cls_loss = self.cls_loss_fn(pred_logit_map, gt_class_map).sum(dim=(-1, -2)).mean()
140
+ loss_info["cls_ce_loss"] = cls_loss.detach()
141
+ else: # self.cls_loss in ["mae", "mse"]
142
+ gt_prob_map = F.one_hot(gt_class_map, num_classes=self.num_bins).float() # B, H, W -> B, H, W, N
143
+ gt_prob_map = gt_prob_map.permute(0, 3, 1, 2) # B, H, W, N -> B, N, H, W
144
+ pred_prob_map = pred_logit_map.softmax(dim=1)
145
+ cls_loss = self.cls_loss_fn(pred_prob_map, gt_prob_map).sum(dim=(-1, -2)).mean()
146
+ loss_info[f"cls_{self.cls_loss}_loss"] = cls_loss.detach()
147
+ else:
148
+ cls_loss = 0
149
+
150
+ if self.weight_reg > 0:
151
+ if self.reg_loss == "dm":
152
+ reg_loss, reg_loss_info = self.reg_loss_fn(
153
+ pred_den_map=pred_den_map,
154
+ gt_den_map=gt_den_map,
155
+ gt_points=gt_points,
156
+ )
157
+ loss_info.update({f"reg_{k}": v for k, v in reg_loss_info.items()})
158
+ elif self.reg_loss == "msmae":
159
+ reg_loss, reg_loss_info = self.reg_loss_fn(pred_den_map, gt_den_map)
160
+ loss_info.update({f"reg_{k}": v for k, v in reg_loss_info.items()})
161
+ else: # self.reg_loss in ["mae", "mse"]
162
+ reg_loss = self.reg_loss_fn(pred_den_map, gt_den_map).sum(dim=(-1, -2)).mean()
163
+ loss_info[f"reg_{self.reg_loss}_loss"] = reg_loss.detach()
164
+ else:
165
+ reg_loss = 0
166
+
167
+ gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32, device=pred_den_map.device)
168
+ cnt_loss = self.cnt_loss_fn(pred_den_map.sum(dim=(1, 2, 3)), gt_cnt).mean()
169
+ loss_info["cnt_loss"] = cnt_loss.detach()
170
+
171
+ total_loss = self.weight_cls * cls_loss + self.weight_reg * reg_loss + cnt_loss
172
+ loss_info["total_loss"] = total_loss.detach()
173
+ loss_info = dict(sorted(loss_info.items())) # sort by key for nicer printing
174
+
175
+ return total_loss, loss_info
losses/loss.py ADDED
@@ -0,0 +1,204 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ from typing import List, Dict, Optional, Tuple, Union
5
+
6
+ from .dm_loss import DMLoss
7
+ from .multiscale_mae import MultiscaleMAE
8
+ from .poisson_nll import PoissonNLL
9
+ from .zero_inflated_poisson_nll import ZIPoissonNLL, ZICrossEntropy
10
+ from .utils import _reshape_density, _bin_count
11
+
12
+
13
+ EPS = 1e-8
14
+
15
+
16
+ class QuadLoss(nn.Module):
17
+ def __init__(
18
+ self,
19
+ input_size: int,
20
+ block_size: int,
21
+ bins: List[Tuple[float, float]],
22
+ reg_loss: str = "zipnll",
23
+ aux_loss: str = "none",
24
+ weight_cls: float = 1.0,
25
+ weight_reg: float = 1.0,
26
+ weight_aux: Optional[float] = None,
27
+ numItermax: Optional[int] = 100,
28
+ regularization: Optional[float] = 10.0,
29
+ scales: Optional[List[int]] = [1, 2, 4],
30
+ min_scale_weight: Optional[float] = 0.0,
31
+ max_scale_weight: Optional[float] = 1.0,
32
+ alpha: Optional[float] = 0.5,
33
+ ) -> None:
34
+ super().__init__()
35
+ assert input_size % block_size == 0, f"Expected input_size to be divisible by block_size, got {input_size} and {block_size}"
36
+ assert len(bins) >= 2, f"Expected bins to have at least 2 elements, got {len(bins)}"
37
+ assert all([len(b) == 2 for b in bins]), f"Expected all bins to be of length 2, got {bins}"
38
+ bins = [(float(low), float(high)) for low, high in bins]
39
+ assert all([b[0] <= b[1] for b in bins]), f"Expected each bin to have bin[0] <= bin[1], got {bins}"
40
+ assert reg_loss in ["zipnll", "pnll", "dm", "msmae", "mae", "mse"], f"Expected reg_loss to be one of ['zipnll', 'pnll', 'dm', 'msmae', 'mae', 'mse'], got {reg_loss}"
41
+ assert aux_loss in ["zipnll", "pnll", "dm", "msmae", "mae", "mse", "none"], f"Expected aux_loss to be one of ['zipnll', 'pnll', 'dm', 'msmae', 'mae', 'mse', 'none'], got {aux_loss}"
42
+
43
+ assert weight_cls >= 0, f"Expected weight_cls to be non-negative, got {weight_cls}"
44
+ assert weight_reg >= 0, f"Expected weight_reg to be non-negative, got {weight_reg}"
45
+ assert not (weight_cls == 0 and weight_reg == 0), "Expected at least one of weight_cls and weight_reg to be non-zero"
46
+ weight_aux = 0 if aux_loss == "none" or weight_aux is None else weight_aux
47
+ assert weight_aux >= 0, f"Expected weight_aux to be non-negative, got {weight_aux}"
48
+
49
+ self.input_size = input_size
50
+ self.block_size = block_size
51
+ self.bins = bins
52
+ self.reg_loss = reg_loss
53
+ self.aux_loss = aux_loss
54
+ self.weight_cls = weight_cls
55
+ self.weight_reg = weight_reg
56
+ self.weight_aux = weight_aux
57
+
58
+ self.num_bins = len(bins)
59
+ self.num_blocks_h = input_size // block_size
60
+ self.num_blocks_w = input_size // block_size
61
+
62
+ if reg_loss == "zipnll":
63
+ self.cls_loss = "zice"
64
+ self.cls_loss_fn = ZICrossEntropy(bins=bins, reduction="mean")
65
+ self.reg_loss_fn = ZIPoissonNLL(reduction="mean")
66
+ else:
67
+ self.cls_loss = "ce"
68
+ self.cls_loss_fn = nn.CrossEntropyLoss(reduction="none")
69
+ if reg_loss == "pnll":
70
+ self.reg_loss_fn = PoissonNLL(reduction="mean")
71
+ elif reg_loss == "dm":
72
+ assert numItermax is not None and numItermax > 0, f"Expected numItermax to be a positive integer, got {numItermax}"
73
+ assert regularization is not None and regularization > 0, f"Expected regularization to be a positive float, got {regularization}"
74
+ self.reg_loss_fn = DMLoss(
75
+ input_size=input_size,
76
+ block_size=block_size,
77
+ numItermax=numItermax,
78
+ regularization=regularization,
79
+ weight_ot=0.1,
80
+ weight_tv=0.01,
81
+ weight_cnt=0, # count loss will be calculated separately in this module.
82
+ )
83
+ elif reg_loss == "msmae":
84
+ assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
85
+ assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
86
+ assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
87
+ self.reg_loss_fn = MultiscaleMAE(
88
+ scales=sorted(scales),
89
+ min_scale_weight=min_scale_weight,
90
+ max_scale_weight=max_scale_weight,
91
+ alpha=alpha,
92
+ )
93
+ elif reg_loss == "mae":
94
+ self.reg_loss_fn = nn.L1Loss(reduction="none")
95
+ elif reg_loss == "mse":
96
+ self.reg_loss_fn = nn.MSELoss(reduction="none")
97
+ else: # reg_loss == "none"
98
+ self.reg_loss_fn = None
99
+
100
+ if aux_loss == "zipnll":
101
+ self.aux_loss_fn = ZIPoissonNLL(reduction="mean")
102
+ elif aux_loss == "pnll":
103
+ self.aux_loss_fn = PoissonNLL(reduction="mean")
104
+ elif aux_loss == "dm":
105
+ assert numItermax is not None and numItermax > 0, f"Expected numItermax to be a positive integer, got {numItermax}"
106
+ assert regularization is not None and regularization > 0, f"Expected regularization to be a positive float, got {regularization}"
107
+ self.aux_loss_fn = DMLoss(
108
+ input_size=input_size,
109
+ block_size=block_size,
110
+ numItermax=numItermax,
111
+ regularization=regularization,
112
+ weight_ot=0.1,
113
+ weight_tv=0.01,
114
+ weight_cnt=0, # count loss will be calculated separately in this module.
115
+ )
116
+ elif aux_loss == "msmae":
117
+ assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
118
+ assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
119
+ assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
120
+ self.aux_loss_fn = MultiscaleMAE(
121
+ scales=sorted(scales),
122
+ min_scale_weight=min_scale_weight,
123
+ max_scale_weight=max_scale_weight,
124
+ alpha=alpha,
125
+ )
126
+ elif aux_loss == "mae":
127
+ self.aux_loss_fn = nn.L1Loss(reduction="none")
128
+ elif aux_loss == "mse":
129
+ self.aux_loss_fn = nn.MSELoss(reduction="none")
130
+ else: # aux_loss == "none"
131
+ self.aux_loss_fn = None
132
+
133
+ self.cnt_loss_fn = nn.L1Loss(reduction="mean")
134
+
135
+ def forward(
136
+ self,
137
+ pred_logit_map: Tensor,
138
+ pred_den_map: Tensor,
139
+ gt_den_map: Tensor,
140
+ gt_points: List[Tensor],
141
+ pred_logit_pi_map: Optional[Tensor] = None,
142
+ pred_lambda_map: Optional[Tensor] = None,
143
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
144
+ B = pred_den_map.shape[0]
145
+ assert pred_logit_map.shape[-2:] == (self.num_blocks_h, self.num_blocks_w), f"Expected pred_logit_map to have the spatial dimension of {self.num_blocks_h}x{self.num_blocks_w}, got {pred_logit_map.shape}"
146
+ if gt_den_map.shape[-2:] != (self.num_blocks_h, self.num_blocks_w):
147
+ assert gt_den_map.shape[-2:] == (self.input_size, self.input_size), f"Expected gt_den_map to have shape {B, 1, self.input_size, self.input_size}, got {gt_den_map.shape}"
148
+ gt_den_map = _reshape_density(gt_den_map, block_size=self.block_size)
149
+ assert pred_den_map.shape == gt_den_map.shape == (B, 1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map and gt_den_map to have shape (B, 1, H, W), got {pred_den_map.shape} and {gt_den_map.shape}"
150
+ assert len(gt_points) == B, f"Expected gt_points to have length B, got {len(gt_points)}"
151
+
152
+ if self.reg_loss == "zipnll" or self.aux_loss == "zipnll":
153
+ assert pred_logit_pi_map is not None and pred_logit_pi_map.shape == (B, 2, self.num_blocks_h, self.num_blocks_w), f"Expected pred_logit_pi_map to have shape {B, 2, self.num_blocks_h, self.num_blocks_w}, got {pred_logit_pi_map.shape}"
154
+ assert pred_lambda_map is not None and pred_lambda_map.shape == (B, 1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_lambda_map to have shape {B, 1, self.num_blocks_h, self.num_blocks_w}, got {pred_lambda_map.shape}"
155
+
156
+ loss_info = {}
157
+ if self.weight_cls > 0:
158
+ gt_class_map = _bin_count(gt_den_map, bins=self.bins)
159
+ if self.cls_loss == "ce":
160
+ cls_loss = self.cls_loss_fn(pred_logit_map, gt_class_map).sum(dim=(-1, -2)).mean()
161
+ loss_info["cls_ce_loss"] = cls_loss.detach()
162
+ else: # cls_loss == "zice"
163
+ cls_loss, cls_loss_info = self.cls_loss_fn(pred_logit_map, gt_den_map)
164
+ loss_info.update(cls_loss_info)
165
+ else:
166
+ cls_loss = 0
167
+
168
+ if self.weight_reg > 0:
169
+ if self.reg_loss == "zipnll":
170
+ reg_loss, reg_loss_info = self.reg_loss_fn(pred_logit_pi_map, pred_lambda_map, gt_den_map)
171
+ elif self.reg_loss == "dm":
172
+ reg_loss, reg_loss_info = self.reg_loss_fn(pred_den_map, gt_den_map, gt_points)
173
+ elif self.reg_loss in ["pnll", "msmae"]:
174
+ reg_loss, reg_loss_info = self.reg_loss_fn(pred_den_map, gt_den_map)
175
+ else: # reg_loss in ["mae", "mse"]
176
+ reg_loss = self.reg_loss_fn(pred_den_map, gt_den_map).sum(dim=(-1, -2)).mean()
177
+ reg_loss_info = {f"{self.reg_loss}": reg_loss.detach()}
178
+ reg_loss_info = {f"reg_{k}": v for k, v in reg_loss_info.items()}
179
+ loss_info.update(reg_loss_info)
180
+ else:
181
+ reg_loss = 0
182
+
183
+ if self.weight_aux > 0:
184
+ if self.aux_loss == "zipnll":
185
+ aux_loss, aux_loss_info = self.aux_loss_fn(pred_logit_pi_map, pred_lambda_map, gt_den_map)
186
+ elif self.aux_loss in ["pnll", "msmae"]:
187
+ aux_loss, aux_loss_info = self.aux_loss_fn(pred_den_map, gt_den_map)
188
+ elif self.aux_loss == "dm":
189
+ aux_loss, aux_loss_info = self.aux_loss_fn(pred_den_map, gt_den_map, gt_points)
190
+ else:
191
+ aux_loss = self.aux_loss_fn(pred_den_map, gt_den_map).sum(dim=(-1, -2)).mean()
192
+ aux_loss_info = {f"{self.aux_loss}": aux_loss.detach()}
193
+ aux_loss_info = {f"aux_{k}": v for k, v in aux_loss_info.items()}
194
+ loss_info.update(aux_loss_info)
195
+ else:
196
+ aux_loss = 0
197
+
198
+ gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32, device=pred_den_map.device)
199
+ cnt_loss = self.cnt_loss_fn(pred_den_map.sum(dim=(1, 2, 3)), gt_cnt)
200
+ loss_info["cnt_loss"] = cnt_loss.detach()
201
+
202
+ total_loss = self.weight_cls * cls_loss + self.weight_reg * reg_loss + self.weight_aux * aux_loss + cnt_loss
203
+ return total_loss, loss_info
204
+
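
A minimal sketch of `QuadLoss` in its default zero-inflated configuration (`reg_loss="zipnll"`); the bins below are illustrative placeholders, the actual bins live in `configs/bin_config.json`:

```python
import torch
from losses import QuadLoss  # repository root on PYTHONPATH

input_size, block_size = 256, 16
bins = [(0, 0), (1, 1), (2, 2), (3, 5), (6, float("inf"))]  # illustrative count bins
criterion = QuadLoss(input_size=input_size, block_size=block_size, bins=bins)

B, H = 2, input_size // block_size                       # 16x16 blocks
pred_logit_map = torch.randn(B, len(bins) - 1, H, H)     # logits over the non-zero bins
pred_den_map = torch.rand(B, 1, H, H)                    # expected count per block
pred_logit_pi_map = torch.randn(B, 2, H, H)              # structural-zero vs. Poisson logits
pred_lambda_map = torch.rand(B, 1, H, H) * 3             # Poisson rate per block
gt_den_map = torch.randint(0, 7, (B, 1, H, H)).float()   # blockwise GT counts
gt_points = [torch.rand(int(gt_den_map[i].sum()), 2) * input_size for i in range(B)]

loss, info = criterion(
    pred_logit_map, pred_den_map, gt_den_map, gt_points,
    pred_logit_pi_map=pred_logit_pi_map, pred_lambda_map=pred_lambda_map,
)
print(loss.item(), sorted(info.keys()))
```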
losses/multiscale_mae.py ADDED
@@ -0,0 +1,55 @@
1
+ from torch import nn, Tensor
2
+ import math
3
+ from typing import List, Optional, Dict, Tuple
4
+
5
+
6
+ class MultiscaleMAE(nn.Module):
7
+ def __init__(
8
+ self,
9
+ scales: List[int] = [1, 2, 4],
10
+ min_scale_weight: float = 0.0,
11
+ max_scale_weight: float = 1.0,
12
+ alpha: float = 0.5,
13
+ weights: Optional[List[float]] = None,
14
+ ) -> None:
15
+ super().__init__()
16
+ assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
17
+ assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
18
+ assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
19
+ self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight
20
+
21
+ scales = sorted(scales) # sort scales in ascending order so that the last one is the largest
22
+ weights = [min_scale_weight + (max_scale_weight - min_scale_weight) * alpha ** (math.log2(scales[-1] / s)) for s in scales] if weights is None else weights # e.g., [1, 2, 4, 8] -> [0.125, 0.25, 0.5, 1]
23
+
24
+ assert len(scales) == len(weights), f"Expected scales and weights to have the same length, got {len(scales)} and {len(weights)}"
25
+ self.scales, self.weights = scales, weights
26
+
27
+ for idx in range(len(scales)):
28
+ setattr(self, f"pool_{scales[idx]}", nn.AvgPool2d(kernel_size=scales[idx], stride=scales[idx]) if scales[idx] > 1 else nn.Identity())
29
+ setattr(self, f"weight_{scales[idx]}", weights[idx])
30
+ setattr(self, f"mae_loss_fn_{scales[idx]}", nn.L1Loss(reduction="none"))
31
+
32
+ def forward(
33
+ self,
34
+ pred_den_map: Tensor,
35
+ gt_den_map: Tensor,
36
+ ) -> Tuple[Tensor, Dict]:
37
+ assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
38
+ assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
39
+ assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"
40
+ assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"
41
+
42
+ loss, loss_info = 0, {}
43
+ for idx in range(len(self.scales)):
44
+ pool = getattr(self, f"pool_{self.scales[idx]}")
45
+ weight = getattr(self, f"weight_{self.scales[idx]}")
46
+ loss_fn = getattr(self, f"mae_loss_fn_{self.scales[idx]}")
47
+
48
+ pred_den_map_pool = pool(pred_den_map)
49
+ gt_den_map_pool = pool(gt_den_map)
50
+
51
+ mae_loss_scale = loss_fn(pred_den_map_pool, gt_den_map_pool).sum(dim=(-1, -2)).mean()
52
+ loss += weight * mae_loss_scale
53
+ loss_info[f"mae_loss_{self.scales[idx]}"] = mae_loss_scale.detach()
54
+
55
+ return loss, loss_info
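
With the default settings, scale `s` receives weight `min_scale_weight + (max_scale_weight - min_scale_weight) * alpha^(log2(s_max / s))`, so coarser scales are weighted more heavily. A minimal sketch:

```python
import torch
from losses.multiscale_mae import MultiscaleMAE  # repository root on PYTHONPATH

criterion = MultiscaleMAE(scales=[1, 2, 4], alpha=0.5)
print(criterion.weights)  # [0.25, 0.5, 1.0] -- the 4x4-pooled scale gets the largest weight

pred = torch.rand(2, 1, 32, 32)
gt = torch.rand(2, 1, 32, 32)
loss, info = criterion(pred, gt)
print(loss.item(), sorted(info.keys()))  # per-scale terms: mae_loss_1, mae_loss_2, mae_loss_4
```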
losses/poisson_nll.py ADDED
@@ -0,0 +1,46 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from .utils import _reshape_density
4
+
5
+
6
+ EPS = 1e-8
7
+
8
+
9
+ class PoissonNLL(nn.Module):
10
+ def __init__(
11
+ self,
12
+ reduction: str = "mean",
13
+ ) -> None:
14
+ super().__init__()
15
+ assert reduction in ["none", "mean", "sum"], f"Expected reduction to be one of ['none', 'mean', 'sum'], got {reduction}."
16
+ self.reduction = reduction
17
+
18
+ def forward(self, pred_den_map: Tensor, gt_den_map: Tensor) -> Tensor:
19
+ """
20
+ Args:
21
+ pred_den_map: predicted λ map, shape (B, 1, H, W)
22
+ gt_den_map: ground truth density map, shape (B, 1, H, W)
23
+ Returns:
24
+ Poisson loss
25
+ """
26
+ assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
27
+ assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
28
+ assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"
29
+ if gt_den_map.shape != pred_den_map.shape:
30
+ gt_h, gt_w = gt_den_map.shape[-2], gt_den_map.shape[-1]
31
+ pred_h, pred_w = pred_den_map.shape[-2], pred_den_map.shape[-1]
32
+ assert gt_h % pred_h == 0 and gt_w % pred_w == 0 and gt_h // pred_h == gt_w // pred_w, f"Expected the spatial dimension of gt_den_map to be a multiple of that of pred_den_map, got {gt_den_map.shape} and {pred_den_map.shape}"
33
+ gt_den_map = _reshape_density(gt_den_map, block_size=gt_h // pred_h)
34
+
35
+ assert gt_den_map.shape == pred_den_map.shape, f"Expected gt_den_map and pred_den_map to have the same shape, got {gt_den_map.shape} and {pred_den_map.shape}"
36
+
37
+ gt_den_map = gt_den_map.to(pred_den_map.device)
38
+
39
+ loss = (pred_den_map - gt_den_map * torch.log(pred_den_map + EPS)).sum(dim=(-1, -2)) # sum over H and W
40
+
41
+ if self.reduction == "mean":
42
+ loss = loss.mean()
43
+ elif self.reduction == "sum":
44
+ loss = loss.sum()
45
+
46
+ return loss, {"pnll": loss.detach()}
losses/utils.py ADDED
@@ -0,0 +1,19 @@
1
+ import torch
2
+ from torch import Tensor
3
+ from typing import List, Tuple
4
+
5
+
6
+ def _reshape_density(density: Tensor, block_size: int) -> Tensor:
7
+ assert len(density.shape) == 4, f"Expected 4D (B, 1, H, W) tensor, got {density.shape}"
8
+ assert density.shape[1] == 1, f"Expected 1 channel, got {density.shape[1]}"
9
+ assert density.shape[2] % block_size == 0, f"Expected height to be divisible by {block_size}, got {density.shape[2]}"
10
+ assert density.shape[3] % block_size == 0, f"Expected width to be divisible by {block_size}, got {density.shape[3]}"
11
+ return density.reshape(density.shape[0], 1, density.shape[2] // block_size, block_size, density.shape[3] // block_size, block_size).sum(dim=(-1, -3))
12
+
13
+
14
+ def _bin_count(density_map: Tensor, bins: List[Tuple[int, int]]) -> Tensor:
15
+ class_map = torch.zeros_like(density_map, dtype=torch.long)
16
+ for idx, (low, high) in enumerate(bins):
17
+ mask = (density_map >= low) & (density_map <= high)
18
+ class_map[mask] = idx
19
+ return class_map.squeeze(1) # remove channel dimension
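
`_reshape_density` folds a pixel-level density map into blockwise counts by summing each `block_size × block_size` cell, so the total count is preserved. A minimal sketch:

```python
import torch
from losses.utils import _reshape_density, _bin_count  # repository root on PYTHONPATH

den = torch.ones(1, 1, 8, 8)                # 64 "people" spread uniformly over an 8x8 map
blocks = _reshape_density(den, block_size=4)
print(blocks.shape)                         # torch.Size([1, 1, 2, 2])
print(blocks.sum().item())                  # 64.0 -- total count is preserved

cls = _bin_count(blocks, bins=[(0, 0), (1, 15), (16, float("inf"))])
print(cls)                                  # each 4x4 block holds 16, so every block is class 2
```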
losses/zero_inflated_poisson_nll.py ADDED
@@ -0,0 +1,96 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from einops import rearrange
4
+ from typing import List, Tuple
5
+ from .utils import _reshape_density, _bin_count
6
+
7
+ EPS = 1e-8
8
+
9
+
10
+ class ZIPoissonNLL(nn.Module):
11
+ def __init__(
12
+ self,
13
+ reduction: str = "mean",
14
+ ) -> None:
15
+ super().__init__()
16
+ assert reduction in ["none", "mean", "sum"], f"Expected reduction to be one of ['none', 'mean', 'sum'], got {reduction}."
17
+ self.reduction = reduction
18
+
19
+ def forward(
20
+ self,
21
+ logit_pi_maps: Tensor,
22
+ lambda_maps: Tensor,
23
+ gt_den_maps: Tensor,
24
+ ) -> Tensor:
25
+ assert len(logit_pi_maps.shape) == len(lambda_maps.shape) == len(gt_den_maps.shape) == 4, f"Expected 4D (B, C, H, W) tensor, got {logit_pi_maps.shape}, {lambda_maps.shape}, and {gt_den_maps.shape}"
26
+ B, _, H, W = lambda_maps.shape
27
+ assert logit_pi_maps.shape == (B, 2, H, W), f"Expected logit_pi_maps to have shape (B, 2, H, W), got {logit_pi_maps.shape}"
28
+ assert lambda_maps.shape == (B, 1, H, W), f"Expected lambda_maps to have shape (B, 1, H, W), got {lambda_maps.shape}"
29
+ if gt_den_maps.shape[2:] != (H, W):
30
+ gt_h, gt_w = gt_den_maps.shape[-2], gt_den_maps.shape[-1]
31
+ assert gt_h % H == 0 and gt_w % W == 0 and gt_h // H == gt_w // W, f"Expected the spatial dimension of gt_den_maps to be a multiple of that of lambda_maps, got {gt_den_maps.shape} and {lambda_maps.shape}"
32
+ gt_den_maps = _reshape_density(gt_den_maps, block_size=gt_h // H)
33
+ assert gt_den_maps.shape == (B, 1, H, W), f"Expected gt_den_maps to have shape (B, 1, H, W), got {gt_den_maps.shape}"
34
+
35
+ pi_maps = logit_pi_maps.softmax(dim=1)
36
+ zero_indices = (gt_den_maps == 0).float()
37
+ zero_loss = -torch.log(pi_maps[:, 0:1] + pi_maps[:, 1:] * torch.exp(-lambda_maps) + EPS) * zero_indices
38
+
39
+ poisson_log_p = gt_den_maps * torch.log(lambda_maps + EPS) - lambda_maps
40
+ nonzero_loss = (-torch.log(pi_maps[:, 1:] + EPS) - poisson_log_p) * (1.0 - zero_indices)
41
+
42
+ loss = (zero_loss + nonzero_loss).sum(dim=(-1, -2))
43
+ if self.reduction == "mean":
44
+ loss = loss.mean()
45
+ elif self.reduction == "sum":
46
+ loss = loss.sum()
47
+
48
+ return loss, {"zipnll": loss.detach()}
49
+
50
+
51
+ class ZICrossEntropy(nn.Module):
52
+ def __init__(
53
+ self,
54
+ bins: List[Tuple[int, int]],
55
+ reduction: str = "mean",
56
+ ) -> None:
57
+ super().__init__()
58
+ assert all([low <= high for low, high in bins]), f"Expected bins to be a list of tuples (low, high) where low <= high, got {bins}"
59
+ assert reduction in ["mean", "sum"], f"Expected reduction to be one of ['none', 'mean', 'sum'], got {reduction}."
60
+
61
+ self.bins = bins
62
+ self.reduction = reduction
63
+ self.ce_loss_fn = nn.CrossEntropyLoss(reduction="none")
64
+
65
+ def forward(
66
+ self,
67
+ logit_maps: Tensor,
68
+ gt_den_maps: Tensor,
69
+ ) -> Tensor:
70
+ assert len(logit_maps.shape) == len(gt_den_maps.shape) == 4, f"Expected 4D (B, C, H, W) tensor, got {logit_maps.shape} and {gt_den_maps.shape}"
71
+ B, _, H, W = logit_maps.shape
72
+ assert logit_maps.shape[0] == B and logit_maps.shape[2:] == (H, W), f"Expected logit_maps to have shape (B, C, H, W), got {logit_maps.shape}"
73
+ if gt_den_maps.shape[2:] != (H, W):
74
+ gt_h, gt_w = gt_den_maps.shape[-2], gt_den_maps.shape[-1]
75
+ assert gt_h % H == 0 and gt_w % W == 0 and gt_h // H == gt_w // W, f"Expected the spatial dimension of gt_den_maps to be a multiple of that of logit_maps, got {gt_den_maps.shape} and {logit_maps.shape}"
76
+ gt_den_maps = _reshape_density(gt_den_maps, block_size=gt_h // H)
77
+ assert gt_den_maps.shape == (B, 1, H, W), f"Expected gt_den_maps to have shape (B, 1, H, W), got {gt_den_maps.shape}"
78
+
79
+ gt_class_maps = _bin_count(gt_den_maps, bins=self.bins)
80
+ gt_class_maps = rearrange(gt_class_maps, "B H W -> B (H W)") # flatten spatial dimensions
81
+ logit_maps = rearrange(logit_maps, "B C H W -> B (H W) C") # flatten spatial dimensions
82
+
83
+ loss = 0.0
84
+ for idx in range(gt_class_maps.shape[0]):
85
+ gt_class_map, logit_map = gt_class_maps[idx], logit_maps[idx]
86
+ mask = gt_class_map > 0
87
+ # Find gt_class_map values and logit_maps values where gt_class_map > 0
88
+ gt_class_map = gt_class_map[mask] - 1
89
+ logit_map = logit_map[mask]
90
+ loss += self.ce_loss_fn(logit_map, gt_class_map).sum()
91
+
92
+ if self.reduction == "mean":
93
+ loss /= gt_class_maps.shape[0]
94
+
95
+ return loss, {"cls_zice": loss.detach()}
96
+
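
Per block, the zero-inflated Poisson NLL has two branches: `-log(π₀ + π₁·e^(-λ))` for empty blocks and `-log π₁ - (y·log λ - λ)` for blocks with `y > 0`. A minimal sketch checking `ZIPoissonNLL` against these expressions on a 1×1 map (numbers are arbitrary; values agree up to the `EPS` stabilizer):

```python
import math
import torch
from losses.zero_inflated_poisson_nll import ZIPoissonNLL  # repository root on PYTHONPATH

criterion = ZIPoissonNLL(reduction="mean")

logit_pi = torch.tensor([[[[0.2]], [[1.3]]]])   # shape (1, 2, 1, 1): [zero-logit, Poisson-logit]
lam = torch.tensor([[[[2.0]]]])                 # shape (1, 1, 1, 1)
pi0, pi1 = torch.softmax(logit_pi, dim=1).flatten().tolist()

# Empty block: y = 0
loss0, _ = criterion(logit_pi, lam, torch.zeros(1, 1, 1, 1))
print(loss0.item(), -math.log(pi0 + pi1 * math.exp(-2.0)))

# Occupied block: y = 3
loss3, _ = criterion(logit_pi, lam, torch.full((1, 1, 1, 1), 3.0))
print(loss3.item(), -math.log(pi1) - (3.0 * math.log(2.0) - 2.0))
```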
models/__init__.py ADDED
@@ -0,0 +1,155 @@
1
+ import os, torch
2
+ from typing import List, Tuple, Optional, Union, Dict
3
+
4
+ from .ebc import _ebc, EBC
5
+ from .clip_ebc import _clip_ebc, CLIP_EBC
6
+
7
+
8
+ def get_model(
9
+ model_info_path: str,
10
+ model_name: Optional[str] = None,
11
+ block_size: Optional[int] = None,
12
+ bins: Optional[List[Tuple[float, float]]] = None,
13
+ bin_centers: Optional[List[float]] = None,
14
+ zero_inflated: Optional[bool] = True,
15
+ # parameters for CLIP_EBC
16
+ clip_weight_name: Optional[str] = None,
17
+ num_vpt: Optional[int] = None,
18
+ vpt_drop: Optional[float] = None,
19
+ input_size: Optional[int] = None,
20
+ adapter: bool = False,
21
+ adapter_reduction: Optional[int] = None,
22
+ lora: bool = False,
23
+ lora_rank: Optional[int] = None,
24
+ lora_alpha: Optional[int] = None,
25
+ lora_dropout: Optional[float] = None,
26
+ norm: str = "none",
27
+ act: str = "none",
28
+ text_prompts: Optional[List[str]] = None
29
+ ) -> Union[EBC, CLIP_EBC]:
30
+ if os.path.exists(model_info_path):
31
+ model_info = torch.load(model_info_path, map_location="cpu", weights_only=False)
32
+
33
+ model_name = model_info["config"]["model_name"]
34
+ block_size = model_info["config"]["block_size"]
35
+ bins = model_info["config"]["bins"]
36
+ bin_centers = model_info["config"]["bin_centers"]
37
+ zero_inflated = model_info["config"]["zero_inflated"]
38
+
39
+ clip_weight_name = model_info["config"].get("clip_weight_name", None)
40
+
41
+ num_vpt = model_info["config"].get("num_vpt", None)
42
+ vpt_drop = model_info["config"].get("vpt_drop", None)
43
+
44
+
45
+ adapter = model_info["config"].get("adapter", False)
46
+ adapter_reduction = model_info["config"].get("adapter_reduction", None)
47
+
48
+ lora = model_info["config"].get("lora", False)
49
+ lora_rank = model_info["config"].get("lora_rank", None)
50
+ lora_alpha = model_info["config"].get("lora_alpha", None)
51
+ lora_dropout = model_info["config"].get("lora_dropout", None)
52
+
53
+ input_size = model_info["config"].get("input_size", None)
54
+ text_prompts = model_info["config"].get("text_prompts", None)
55
+
56
+ norm = model_info["config"].get("norm", "none")
57
+ act = model_info["config"].get("act", "none")
58
+
59
+ weights = model_info["weights"]
60
+
61
+ else:
62
+ assert model_name is not None, "model_name should be provided if model_info_path is not provided"
63
+ assert block_size is not None, "block_size should be provided"
64
+ assert bins is not None, "bins should be provided"
65
+ assert bin_centers is not None, "bin_centers should be provided"
66
+ weights = None
67
+
68
+ if "ViT" in model_name:
69
+ assert num_vpt is not None, f"num_vpt should be provided for ViT models, got {num_vpt}"
70
+ assert vpt_drop is not None, f"vpt_drop should be provided for ViT models, got {vpt_drop}"
71
+
72
+ if model_name.startswith("CLIP_") or model_name.startswith("CLIP-"):
73
+ assert clip_weight_name is not None, f"clip_weight_name should be provided for CLIP models, got {clip_weight_name}"
74
+ model = _clip_ebc(
75
+ model_name=model_name[5:],
76
+ weight_name=clip_weight_name,
77
+ block_size=block_size,
78
+ bins=bins,
79
+ bin_centers=bin_centers,
80
+ zero_inflated=zero_inflated,
81
+ num_vpt=num_vpt,
82
+ vpt_drop=vpt_drop,
83
+ input_size=input_size,
84
+ adapter=adapter,
85
+ adapter_reduction=adapter_reduction,
86
+ lora=lora,
87
+ lora_rank=lora_rank,
88
+ lora_alpha=lora_alpha,
89
+ lora_dropout=lora_dropout,
90
+ text_prompts=text_prompts,
91
+ norm=norm,
92
+ act=act
93
+ )
94
+ model_config = {
95
+ "model_name": model_name,
96
+ "block_size": block_size,
97
+ "bins": bins,
98
+ "bin_centers": bin_centers,
99
+ "zero_inflated": zero_inflated,
100
+ "clip_weight_name": clip_weight_name,
101
+ "num_vpt": num_vpt,
102
+ "vpt_drop": vpt_drop,
103
+ "input_size": input_size,
104
+ "adapter": adapter,
105
+ "adapter_reduction": adapter_reduction,
106
+ "lora": lora,
107
+ "lora_rank": lora_rank,
108
+ "lora_alpha": lora_alpha,
109
+ "lora_dropout": lora_dropout,
110
+ "text_prompts": model.text_prompts,
111
+ "norm": norm,
112
+ "act": act
113
+ }
114
+
115
+ else:
116
+ assert not adapter, "adapter for non-CLIP models is not implemented yet"
117
+ assert not lora, "lora for non-CLIP models is not implemented yet"
118
+ model = _ebc(
119
+ model_name=model_name,
120
+ block_size=block_size,
121
+ bins=bins,
122
+ bin_centers=bin_centers,
123
+ zero_inflated=zero_inflated,
124
+ num_vpt=num_vpt,
125
+ vpt_drop=vpt_drop,
126
+ input_size=input_size,
127
+ norm=norm,
128
+ act=act
129
+ )
130
+ model_config = {
131
+ "model_name": model_name,
132
+ "block_size": block_size,
133
+ "bins": bins,
134
+ "bin_centers": bin_centers,
135
+ "zero_inflated": zero_inflated,
136
+ "num_vpt": num_vpt,
137
+ "vpt_drop": vpt_drop,
138
+ "input_size": input_size,
139
+ "norm": norm,
140
+ "act": act
141
+ }
142
+
143
+ model.config = model_config
144
+ model_info = {"config": model_config, "weights": weights}
145
+
146
+ if weights is not None:
147
+ model.load_state_dict(weights)
148
+
149
+ if not os.path.exists(model_info_path):
150
+ torch.save(model_info, model_info_path)
151
+
152
+ return model
153
+
154
+
155
+ __all__ = ["get_model"]
models/clip_ebc/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .model import CLIP_EBC, _clip_ebc
2
+
3
+
4
+ __all__ = [
5
+ "CLIP_EBC",
6
+ "_clip_ebc",
7
+ ]
models/clip_ebc/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (264 Bytes).
models/clip_ebc/__pycache__/convnext.cpython-312.pyc ADDED
Binary file (8.36 kB).
models/clip_ebc/__pycache__/mobileclip.cpython-312.pyc ADDED
Binary file (7.82 kB).
models/clip_ebc/__pycache__/model.cpython-312.pyc ADDED
Binary file (12.8 kB).
models/clip_ebc/__pycache__/resnet.cpython-312.pyc ADDED
Binary file (9.74 kB).
models/clip_ebc/__pycache__/utils.cpython-312.pyc ADDED
Binary file (9.93 kB).
models/clip_ebc/__pycache__/vit.cpython-312.pyc ADDED
Binary file (16.7 kB).
models/clip_ebc/__pycache__/vit_siglip.cpython-312.pyc ADDED
Binary file (13.5 kB).
models/clip_ebc/convnext.py ADDED
@@ -0,0 +1,199 @@
1
+ from torch import nn, Tensor
2
+ import open_clip
3
+ from peft import get_peft_model, LoraConfig
4
+
5
+ from ..utils import ConvRefine, ConvAdapter
6
+ from ..utils import ConvUpsample, _get_norm_layer, _get_activation
7
+
8
+
9
+ convnext_names_and_weights = {
10
+ "convnext_base": ["laion400m_s13b_b51k"], # 107.49M
11
+ "convnext_base_w": ["laion2b_s13b_b82k", "laion2b_s13b_b82k_augreg", "laion_aesthetic_s13b_b82k"], # 107.75M
12
+ "convnext_base_w_320": ["laion_aesthetic_s13b_b82k", "laion_aesthetic_s13b_b82k_augreg"], # 107.75M
13
+ "convnext_large_d": ["laion2b_s26b_b102k_augreg"], # 217.46M
14
+ "convnext_large_d_320": ["laion2b_s29b_b131k_ft", "laion2b_s29b_b131k_ft_soup"], # 217.46M
15
+ "convnext_xxlarge": ["laion2b_s34b_b82k_augreg", "laion2b_s34b_b82k_augreg_rewind", "laion2b_s34b_b82k_augreg_soup"] # 896.88M
16
+ }
17
+
18
+ refiner_channels = {
19
+ "convnext_base": 1024,
20
+ "convnext_base_w": 1024,
21
+ "convnext_base_w_320": 1024,
22
+ "convnext_large_d": 1536,
23
+ "convnext_large_d_320": 1536,
24
+ "convnext_xxlarge": 3072,
25
+ }
26
+
27
+ refiner_groups = {
28
+ "convnext_base": 1,
29
+ "convnext_base_w": 1,
30
+ "convnext_base_w_320": 1,
31
+ "convnext_large_d": refiner_channels["convnext_large_d"] // 512, # 3
32
+ "convnext_large_d_320": refiner_channels["convnext_large_d_320"] // 512, # 3
33
+ "convnext_xxlarge": refiner_channels["convnext_xxlarge"] // 512, # 6
34
+ }
35
+
36
+
37
+
38
+ class ConvNeXt(nn.Module):
39
+ def __init__(
40
+ self,
41
+ model_name: str,
42
+ weight_name: str,
43
+ block_size: int = 16,
44
+ adapter: bool = False,
45
+ adapter_reduction: int = 4,
46
+ norm: str = "none",
47
+ act: str = "none"
48
+ ) -> None:
49
+ super(ConvNeXt, self).__init__()
50
+ assert model_name in convnext_names_and_weights, f"Model name should be one of {list(convnext_names_and_weights.keys())}, but got {model_name}."
51
+ assert weight_name in convnext_names_and_weights[model_name], f"Pretrained should be one of {convnext_names_and_weights[model_name]}, but got {weight_name}."
52
+ assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
53
+ self.model_name, self.weight_name = model_name, weight_name
54
+ self.block_size = block_size
55
+
56
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
57
+
58
+ self.adapter = adapter
59
+ if adapter:
60
+ self.adapter_reduction = adapter_reduction
61
+ for param in model.parameters():
62
+ param.requires_grad = False
63
+
64
+ self.stem = model.trunk.stem
65
+ self.depth = len(model.trunk.stages)
66
+ for idx, stage in enumerate(model.trunk.stages):
67
+ setattr(self, f"stage{idx}", stage)
68
+ if adapter:
69
+ setattr(self, f"adapter{idx}", ConvAdapter(
70
+ in_channels=stage.blocks[-1].mlp.fc2.out_features,
71
+ bottleneck_channels=stage.blocks[-1].mlp.fc2.out_features // adapter_reduction,
72
+ ) if idx < self.depth - 1 else nn.Identity()) # No adapter for the last stage
73
+
74
+ if self.model_name in ["convnext_base", "convnext_base_w", "convnext_base_w_320", "convnext_xxlarge"]:
75
+ self.in_features, self.out_features = model.head.proj.in_features, model.head.proj.out_features
76
+ else: # "convnext_large_d", "convnext_large_d_320":
77
+ self.in_features, self.out_features = model.head.mlp.fc1.in_features, model.head.mlp.fc2.out_features
78
+
79
+ if norm == "bn":
80
+ norm_layer = nn.BatchNorm2d
81
+ elif norm == "ln":
82
+ norm_layer = nn.LayerNorm
83
+ else:
84
+ norm_layer = _get_norm_layer(model)
85
+
86
+ if act == "relu":
87
+ activation = nn.ReLU(inplace=True)
88
+ elif act == "gelu":
89
+ activation = nn.GELU()
90
+ else:
91
+ activation = _get_activation(model)
92
+
93
+ if block_size == 32:
94
+ self.refiner = ConvRefine(
95
+ in_channels=self.in_features,
96
+ out_channels=self.in_features,
97
+ norm_layer=norm_layer,
98
+ activation=activation,
99
+ groups=refiner_groups[self.model_name],
100
+ )
101
+ elif block_size == 16:
102
+ self.refiner = ConvUpsample(
103
+ in_channels=self.in_features,
104
+ out_channels=self.in_features,
105
+ norm_layer=norm_layer,
106
+ activation=activation,
107
+ groups=refiner_groups[self.model_name],
108
+ )
109
+ else: # block_size == 8
110
+ self.refiner = nn.Sequential(
111
+ ConvUpsample(
112
+ in_channels=self.in_features,
113
+ out_channels=self.in_features,
114
+ norm_layer=norm_layer,
115
+ activation=activation,
116
+ groups=refiner_groups[self.model_name],
117
+ ),
118
+ ConvUpsample(
119
+ in_channels=self.in_features,
120
+ out_channels=self.in_features,
121
+ norm_layer=norm_layer,
122
+ activation=activation,
123
+ groups=refiner_groups[self.model_name],
124
+ ),
125
+ )
126
+
127
+ def train(self, mode: bool = True):
128
+ if self.adapter and mode:
129
+ # training:
130
+ self.stem.eval()
131
+
132
+ for idx in range(self.depth):
133
+ getattr(self, f"stage{idx}").eval()
134
+ getattr(self, f"adapter{idx}").train()
135
+
136
+ self.refiner.train()
137
+
138
+ else:
139
+ # evaluation:
140
+ for module in self.children():
141
+ module.train(mode)
142
+
143
+ def forward(self, x: Tensor) -> Tensor:
144
+ x = self.stem(x)
145
+
146
+ for idx in range(self.depth):
147
+ x = getattr(self, f"stage{idx}")(x)
148
+ if self.adapter:
149
+ x = getattr(self, f"adapter{idx}")(x)
150
+
151
+ x = self.refiner(x)
152
+ return x
153
+
154
+
155
+ def _convnext(
156
+ model_name: str,
157
+ weight_name: str,
158
+ block_size: int = 16,
159
+ adapter: bool = False,
160
+ adapter_reduction: int = 4,
161
+ lora: bool = False,
162
+ lora_rank: int = 16,
163
+ lora_alpha: float = 32.0,
164
+ lora_dropout: float = 0.1,
165
+ norm: str = "none",
166
+ act: str = "none"
167
+ ) -> ConvNeXt:
168
+ assert not (lora and adapter), "Lora and adapter cannot be used together."
169
+ model = ConvNeXt(
170
+ model_name=model_name,
171
+ weight_name=weight_name,
172
+ block_size=block_size,
173
+ adapter=adapter,
174
+ adapter_reduction=adapter_reduction,
175
+ norm=norm,
176
+ act=act
177
+ )
178
+
179
+ if lora:
180
+ target_modules = []
181
+ for name, module in model.named_modules():
182
+ if isinstance(module, (nn.Linear, nn.Conv2d)) and "refiner" not in name:
183
+ target_modules.append(name)
184
+
185
+ lora_config = LoraConfig(
186
+ r=lora_rank,
187
+ lora_alpha=lora_alpha,
188
+ lora_dropout=lora_dropout,
189
+ bias="none",
190
+ target_modules=target_modules,
191
+ )
192
+ model = get_peft_model(model, lora_config)
193
+
194
+ # Unfreeze refiner
195
+ for name, module in model.named_modules():
196
+ if "refiner" in name:
197
+ module.requires_grad_(True)
198
+
199
+ return model
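
A minimal usage sketch of the ConvNeXt backbone above (assuming `open_clip` can fetch the `laion2b_s13b_b82k` weights): with `adapter=True` the CLIP trunk stays frozen, only the adapters and the refiner are trainable, and the output feature map has stride equal to `block_size`.

```python
# Sketch only: build the frozen CLIP-ConvNeXt backbone with adapters and check
# that the output stride matches the requested block_size (16 here).
import torch
from models.clip_ebc.convnext import _convnext

backbone = _convnext("convnext_base_w", "laion2b_s13b_b82k", block_size=16, adapter=True)
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 448, 448))
print(feats.shape)  # (1, backbone.in_features, 448 // 16, 448 // 16)
```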
models/clip_ebc/mobileclip.py ADDED
@@ -0,0 +1,197 @@
1
+ from torch import nn, Tensor
2
+ import open_clip
3
+ from peft import get_peft_model, LoraConfig
4
+
5
+ from ..utils import ConvRefine, ConvUpsample, ConvAdapter
6
+ from ..utils import _get_norm_layer, _get_activation
7
+
8
+
9
+ mobileclip_names_and_weights = {
10
+ "MobileCLIP-S1": ["datacompdr"],
11
+ "MobileCLIP-S2": ["datacompdr"],
12
+ }
13
+
14
+
15
+ refiner_channels = {
16
+ "MobileCLIP-S1": 1024,
17
+ "MobileCLIP-S2": 1280,
18
+ }
19
+
20
+ refiner_groups = {
21
+ "MobileCLIP-S1": 2,
22
+ "MobileCLIP-S2": 2,
23
+ }
24
+
25
+
26
+ class MobileCLIP(nn.Module):
27
+ def __init__(
28
+ self,
29
+ model_name: str,
30
+ weight_name: str,
31
+ block_size: int = 16,
32
+ adapter: bool = False,
33
+ adapter_reduction: int = 4,
34
+ norm: str = "none",
35
+ act: str = "none"
36
+ ) -> None:
37
+ super().__init__()
38
+ assert model_name in mobileclip_names_and_weights, f"Model name should be one of {list(mobileclip_names_and_weights.keys())}, but got {model_name}."
39
+ assert weight_name in mobileclip_names_and_weights[model_name], f"Pretrained should be one of {mobileclip_names_and_weights[model_name]}, but got {weight_name}."
40
+ assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
41
+ self.model_name, self.weight_name = model_name, weight_name
42
+ self.block_size = block_size
43
+
44
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
45
+
46
+ self.adapter = adapter
47
+ if adapter:
48
+ for param in model.parameters():
49
+ param.requires_grad = False
50
+
51
+ self.stem = model.trunk.stem
52
+ self.stages = model.trunk.stages
53
+
54
+ self.depth = len(model.trunk.stages)
55
+ for idx, stage in enumerate(model.trunk.stages):
56
+ if adapter:
57
+ setattr(self, f"adapter{idx}", ConvAdapter(
58
+ in_channels=stage.blocks[-1].mlp.fc2.out_channels,
59
+ bottleneck_channels=stage.blocks[-1].mlp.fc2.out_channels // adapter_reduction,
60
+ ))
61
+
62
+ self.final_conv = model.trunk.final_conv
63
+
64
+ self.in_features, self.out_features = model.trunk.head.fc.in_features, model.trunk.head.fc.out_features
65
+
66
+ # refine_block = LightConvRefine if model_name == "MobileCLIP-S1" else ConvRefine
67
+ # upsample_block = LightConvUpsample if model_name == "MobileCLIP-S1" else ConvUpsample
68
+
69
+ if norm == "bn":
70
+ norm_layer = nn.BatchNorm2d
71
+ elif norm == "ln":
72
+ norm_layer = nn.LayerNorm
73
+ else:
74
+ norm_layer = _get_norm_layer(model)
75
+
76
+ if act == "relu":
77
+ activation = nn.ReLU(inplace=True)
78
+ elif act == "gelu":
79
+ activation = nn.GELU()
80
+ else:
81
+ activation = _get_activation(model)
82
+
83
+ if block_size == 32:
84
+ self.refiner = ConvRefine(
85
+ in_channels=self.in_features,
86
+ out_channels=self.in_features,
87
+ norm_layer=norm_layer,
88
+ activation=activation,
89
+ groups=refiner_groups[model_name],
90
+ )
91
+ elif block_size == 16:
92
+ self.refiner = ConvUpsample(
93
+ in_channels=self.in_features,
94
+ out_channels=self.in_features,
95
+ norm_layer=norm_layer,
96
+ activation=activation,
97
+ groups=refiner_groups[self.model_name],
98
+ )
99
+ else: # block_size == 8
100
+ self.refiner = nn.Sequential(
101
+ ConvUpsample(
102
+ in_channels=self.in_features,
103
+ out_channels=self.in_features,
104
+ norm_layer=norm_layer,
105
+ activation=activation,
106
+ groups=refiner_groups[self.model_name],
107
+ ),
108
+ ConvUpsample(
109
+ in_channels=self.in_features,
110
+ out_channels=self.in_features,
111
+ norm_layer=norm_layer,
112
+ activation=activation,
113
+ groups=refiner_groups[self.model_name],
114
+ ),
115
+ )
116
+
117
+ def train(self, mode: bool = True):
118
+ if self.adapter and mode:
119
+ # training:
120
+ self.stem.eval()
121
+
122
+ for idx in range(self.depth):
123
+ self.stages[idx].eval()
124
+ getattr(self, f"adapter{idx}").train()
125
+
126
+ self.final_conv.eval()
127
+ self.refiner.train()
128
+
129
+ else:
130
+ # evaluation:
131
+ for module in self.children():
132
+ module.train(mode)
133
+
134
+ def forward(self, x: Tensor) -> Tensor:
135
+ x = self.stem(x)
136
+
137
+ for idx in range(self.depth):
138
+ x = self.stages[idx](x)
139
+ if self.adapter:
140
+ x = getattr(self, f"adapter{idx}")(x)
141
+
142
+ x = self.final_conv(x)
143
+
144
+ x = self.refiner(x)
145
+ return x
146
+
147
+
148
+ def _mobileclip(
149
+ model_name: str,
150
+ weight_name: str,
151
+ block_size: int = 16,
152
+ adapter: bool = False,
153
+ adapter_reduction: int = 4,
154
+ lora: bool = False,
155
+ lora_rank: int = 16,
156
+ lora_alpha: float = 32.0,
157
+ lora_dropout: float = 0.1,
158
+ norm: str = "none",
159
+ act: str = "none"
160
+ ) -> MobileCLIP:
161
+ assert not (lora and adapter), "Lora and adapter cannot be used together."
162
+ model = MobileCLIP(
163
+ model_name=model_name,
164
+ weight_name=weight_name,
165
+ block_size=block_size,
166
+ adapter=adapter,
167
+ adapter_reduction=adapter_reduction,
168
+ norm=norm,
169
+ act=act
170
+ )
171
+
172
+ if lora:
173
+ target_modules = []
174
+ for name, module in model.named_modules():
175
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
176
+ target_modules.append(name)
177
+
178
+ lora_config = LoraConfig(
179
+ r=lora_rank,
180
+ lora_alpha=lora_alpha,
181
+ lora_dropout=lora_dropout,
182
+ bias="none",
183
+ target_modules=target_modules,
184
+ )
185
+ model = get_peft_model(model, lora_config)
186
+
187
+ # Unfreeze the BN layers
188
+ for name, module in model.named_modules():
+ if isinstance(module, nn.BatchNorm2d) and "refiner" not in name:
+ module.requires_grad_(True)
191
+
192
+ # Unfreeze refiner
193
+ for name, module in model.named_modules():
194
+ if "refiner" in name:
195
+ module.requires_grad_(True)
196
+
197
+ return model
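
As with the other backbones, the overridden `train()` is what keeps the frozen MobileCLIP trunk in eval mode (so its BatchNorm statistics are not updated) while the adapters and refiner are trained. A small sketch, assuming the `datacompdr` weights can be downloaded.

```python
# Sketch only: with adapters enabled, .train() leaves the frozen trunk in eval
# mode and only flips the adapters and the refiner to training mode.
from models.clip_ebc.mobileclip import _mobileclip

backbone = _mobileclip("MobileCLIP-S2", "datacompdr", block_size=16, adapter=True)
backbone.train()
print(backbone.stem.training)      # False: frozen trunk stays in eval mode
print(backbone.adapter0.training)  # True:  adapters are trained
print(backbone.refiner.training)   # True:  refiner is trained
```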
models/clip_ebc/model.py ADDED
@@ -0,0 +1,272 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from typing import List, Optional, Dict, Tuple
6
+ from copy import deepcopy
7
+
8
+ from .vit import vit_names_and_weights, _vit
9
+ from .convnext import convnext_names_and_weights, _convnext
10
+ from .resnet import resnet_names_and_weights, _resnet
11
+ from .mobileclip import mobileclip_names_and_weights, _mobileclip
12
+
13
+ from .utils import encode_text, optimize_text_prompts
14
+ from ..utils import conv1x1
15
+
16
+ supported_models_and_weights = deepcopy(vit_names_and_weights)
17
+ supported_models_and_weights.update(convnext_names_and_weights)
18
+ supported_models_and_weights.update(resnet_names_and_weights)
19
+ supported_models_and_weights.update(mobileclip_names_and_weights)
20
+
21
+
22
+ class CLIP_EBC(nn.Module):
23
+ def __init__(
24
+ self,
25
+ model_name: str,
26
+ weight_name: str,
27
+ block_size: Optional[int] = None,
28
+ bins: Optional[List[Tuple[float, float]]] = None,
29
+ bin_centers: Optional[List[float]] = None,
30
+ zero_inflated: Optional[bool] = True,
31
+ num_vpt: Optional[int] = None,
32
+ vpt_drop: Optional[float] = None,
33
+ input_size: Optional[int] = None,
34
+ adapter: Optional[bool] = False,
35
+ adapter_reduction: Optional[int] = None,
36
+ lora: Optional[bool] = False,
37
+ lora_rank: Optional[int] = None,
38
+ lora_alpha: Optional[float] = None,
39
+ lora_dropout: Optional[float] = None,
40
+ text_prompts: Optional[Dict[str, List[str]]] = None,
41
+ norm: Optional[str] = "none",
42
+ act: Optional[str] = "none",
43
+ ) -> None:
44
+ super().__init__()
45
+ if "mobileclip" in model_name.lower() or "vit" in model_name.lower():
46
+ model_name = model_name.replace("_", "-")
47
+ assert model_name in supported_models_and_weights, f"Model name should be one of {list(supported_models_and_weights.keys())}, but got {model_name}."
48
+ assert weight_name in supported_models_and_weights[model_name], f"Pretrained should be one of {supported_models_and_weights[model_name]}, but got {weight_name}."
49
+ assert len(bins) == len(bin_centers), f"Expected bins and bin_centers to have the same length, got {len(bins)} and {len(bin_centers)}"
50
+ assert len(bins) >= 2, f"Expected at least 2 bins, got {len(bins)}"
51
+ assert all(len(b) == 2 for b in bins), f"Expected bins to be a list of tuples of length 2, got {bins}"
52
+ bins = [(float(b[0]), float(b[1])) for b in bins]
53
+ assert all(bin[0] <= p <= bin[1] for bin, p in zip(bins, bin_centers)), f"Expected bin_centers to be within the range of the corresponding bin, got {bins} and {bin_centers}"
54
+
55
+ self.model_name = model_name
56
+ self.weight_name = weight_name
57
+ self.block_size = block_size
58
+ self.bins = bins
59
+ self.register_buffer("bin_centers", torch.tensor(bin_centers, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1))
60
+ self.zero_inflated = zero_inflated
61
+ self.text_prompts = text_prompts
62
+
63
+ # Image encoder
64
+ if model_name in vit_names_and_weights:
65
+ assert num_vpt is not None and num_vpt >= 0, f"Number of VPT tokens should be greater than 0, but got {num_vpt}."
66
+ vpt_drop = 0. if vpt_drop is None else vpt_drop
67
+ self.backbone = _vit(
68
+ model_name=model_name,
69
+ weight_name=weight_name,
70
+ num_vpt=num_vpt,
71
+ vpt_drop=vpt_drop,
72
+ block_size=block_size,
73
+ adapter=adapter,
74
+ adapter_reduction=adapter_reduction,
75
+ lora=lora,
76
+ lora_rank=lora_rank,
77
+ lora_alpha=lora_alpha,
78
+ lora_dropout=lora_dropout,
79
+ input_size=(input_size, input_size),
80
+ norm=norm,
81
+ act=act
82
+ )
83
+ elif model_name in convnext_names_and_weights:
84
+ self.backbone = _convnext(
85
+ model_name=model_name,
86
+ weight_name=weight_name,
87
+ block_size=block_size,
88
+ adapter=adapter,
89
+ adapter_reduction=adapter_reduction,
90
+ lora=lora,
91
+ lora_rank=lora_rank,
92
+ lora_alpha=lora_alpha,
93
+ lora_dropout=lora_dropout,
94
+ norm=norm,
95
+ act=act
96
+ )
97
+ elif model_name in resnet_names_and_weights:
98
+ self.backbone = _resnet(
99
+ model_name=model_name,
100
+ weight_name=weight_name,
101
+ block_size=block_size,
102
+ adapter=adapter,
103
+ adapter_reduction=adapter_reduction,
104
+ lora=lora,
105
+ lora_rank=lora_rank,
106
+ lora_alpha=lora_alpha,
107
+ lora_dropout=lora_dropout,
108
+ norm=norm,
109
+ act=act
110
+ )
111
+ elif model_name in mobileclip_names_and_weights:
112
+ self.backbone = _mobileclip(
113
+ model_name=model_name,
114
+ weight_name=weight_name,
115
+ block_size=block_size,
116
+ adapter=adapter,
117
+ adapter_reduction=adapter_reduction,
118
+ lora=lora,
119
+ lora_rank=lora_rank,
120
+ lora_alpha=lora_alpha,
121
+ lora_dropout=lora_dropout,
122
+ norm=norm,
123
+ act=act
124
+ )
125
+
126
+ self._build_text_feats()
127
+ self._build_head()
128
+
129
+ def _build_text_feats(self) -> None:
130
+ model_name, weight_name = self.model_name, self.weight_name
131
+ text_prompts = self.text_prompts
132
+
133
+ if text_prompts is None:
134
+ bins = [b[0] if b[0] == b[1] else b for b in self.bins] # if the bin is a single value (e.g., [0, 0]), use that value
135
+ if self.zero_inflated: # separate 0 from the rest
136
+ assert bins[0] == 0, f"Expected the first bin to be 0, got {bins[0]}."
137
+ bins_pi = [0, (1, float("inf"))]
138
+ bins_lambda = bins[1:]
139
+ pi_text_prompts = optimize_text_prompts(model_name, weight_name, bins_pi)
140
+ lambda_text_prompts = optimize_text_prompts(model_name, weight_name, bins_lambda)
141
+ self.text_prompts = {"pi": pi_text_prompts, "lambda": lambda_text_prompts}
142
+ pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
143
+ lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
144
+ pi_text_feats.requires_grad = False
145
+ lambda_text_feats.requires_grad = False
146
+ self.register_buffer("pi_text_feats", pi_text_feats)
147
+ self.register_buffer("lambda_text_feats", lambda_text_feats)
148
+
149
+ else:
150
+ text_prompts = optimize_text_prompts(model_name, weight_name, bins)
151
+ self.text_prompts = text_prompts
152
+ text_feats = encode_text(model_name, weight_name, text_prompts)
153
+ text_feats.requires_grad = False
154
+ self.register_buffer("text_feats", text_feats)
155
+
156
+ else:
157
+ if self.zero_inflated:
158
+ assert "pi" in text_prompts and "lambda" in text_prompts, f"Expected text_prompts to have keys 'pi' and 'lambda', got {text_prompts.keys()}."
159
+ pi_text_prompts = text_prompts["pi"]
160
+ lambda_text_prompts = text_prompts["lambda"]
161
+ pi_text_feats = encode_text(model_name, weight_name, pi_text_prompts)
162
+ lambda_text_feats = encode_text(model_name, weight_name, lambda_text_prompts)
163
+ pi_text_feats.requires_grad = False
164
+ lambda_text_feats.requires_grad = False
165
+ self.register_buffer("pi_text_feats", pi_text_feats)
166
+ self.register_buffer("lambda_text_feats", lambda_text_feats)
167
+
168
+ else:
169
+ text_feats = encode_text(model_name, weight_name, text_prompts)
170
+ text_feats.requires_grad = False
171
+ self.register_buffer("text_feats", text_feats)
172
+
173
+ def _build_head(self) -> None:
174
+ in_channels = self.backbone.in_features
175
+ out_channels = self.backbone.out_features
176
+ if self.zero_inflated:
177
+ self.pi_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
178
+ self.lambda_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
179
+
180
+ self.pi_head = conv1x1(in_channels, out_channels, bias=False)
181
+ self.lambda_head = conv1x1(in_channels, out_channels, bias=False)
182
+
183
+ else:
184
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)
185
+ self.head = conv1x1(in_channels, out_channels, bias=False)
186
+
187
+ def forward(self, image: Tensor):
188
+ image_feats = self.backbone(image)
189
+ # image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
190
+
191
+ if self.zero_inflated:
192
+ pi_image_feats, lambda_image_feats = self.pi_head(image_feats), self.lambda_head(image_feats)
193
+ pi_image_feats = F.normalize(pi_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
194
+ lambda_image_feats = F.normalize(lambda_image_feats.permute(0, 2, 3, 1), p=2, dim=-1) # shape (B, H, W, C)
195
+
196
+ pi_text_feats, lambda_text_feats = self.pi_text_feats, self.lambda_text_feats
197
+ pi_logit_scale, lambda_logit_scale = self.pi_logit_scale.exp(), self.lambda_logit_scale.exp()
198
+
199
+ pi_logit_map = pi_logit_scale * pi_image_feats @ pi_text_feats.t() # (B, H, W, 2), logits per image
200
+ lambda_logit_map = lambda_logit_scale * lambda_image_feats @ lambda_text_feats.t() # (B, H, W, N - 1), logits per image
201
+
202
+ pi_logit_map = pi_logit_map.permute(0, 3, 1, 2) # (B, 2, H, W)
203
+ lambda_logit_map = lambda_logit_map.permute(0, 3, 1, 2) # (B, N - 1, H, W)
204
+
205
+ lambda_map = (lambda_logit_map.softmax(dim=1) * self.bin_centers[:, 1:]).sum(dim=1, keepdim=True) # (B, 1, H, W)
206
+
207
+ # pi_logit_map.softmax(dim=1)[:, 0] is the probability of zeros
208
+ den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map # (B, 1, H, W)
209
+
210
+ if self.training:
211
+ return pi_logit_map, lambda_logit_map, lambda_map, den_map
212
+ else:
213
+ return den_map
214
+
215
+ else:
216
+ image_feats = self.head(image_feats)
217
+ image_feats = F.normalize(image_feats.permute(0, 2, 3, 1), p=2, dim=-1)
218
+
219
+ text_feats = self.text_feats
220
+ logit_scale = self.logit_scale.exp()
221
+
222
+ logit_map = logit_scale * image_feats @ text_feats.t() # (B, H, W, N), logits per image
223
+ logit_map = logit_map.permute(0, 3, 1, 2) # (B, N, H, W)
224
+
225
+ den_map = (logit_map.softmax(dim=1) * self.bin_centers).sum(dim=1, keepdim=True) # (B, 1, H, W)
226
+
227
+ if self.training:
228
+ return logit_map, den_map
229
+ else:
230
+ return den_map
231
+
232
+
233
+ def _clip_ebc(
234
+ model_name: str,
235
+ weight_name: str,
236
+ block_size: Optional[int] = None,
237
+ bins: Optional[List[Tuple[float, float]]] = None,
238
+ bin_centers: Optional[List[float]] = None,
239
+ zero_inflated: Optional[bool] = True,
240
+ num_vpt: Optional[int] = None,
241
+ vpt_drop: Optional[float] = None,
242
+ input_size: Optional[int] = None,
243
+ adapter: Optional[bool] = False,
244
+ adapter_reduction: Optional[int] = None,
245
+ lora: Optional[bool] = False,
246
+ lora_rank: Optional[int] = None,
247
+ lora_alpha: Optional[float] = None,
248
+ lora_dropout: Optional[float] = None,
249
+ text_prompts: Optional[Dict[str, List[str]]] = None,
250
+ norm: Optional[str] = "none",
251
+ act: Optional[str] = "none",
252
+ ) -> CLIP_EBC:
253
+ return CLIP_EBC(
254
+ model_name=model_name,
255
+ weight_name=weight_name,
256
+ block_size=block_size,
257
+ bins=bins,
258
+ bin_centers=bin_centers,
259
+ zero_inflated=zero_inflated,
260
+ num_vpt=num_vpt,
261
+ vpt_drop=vpt_drop,
262
+ input_size=input_size,
263
+ adapter=adapter,
264
+ adapter_reduction=adapter_reduction,
265
+ lora=lora,
266
+ lora_rank=lora_rank,
267
+ lora_alpha=lora_alpha,
268
+ lora_dropout=lora_dropout,
269
+ text_prompts=text_prompts,
270
+ norm=norm,
271
+ act=act,
272
+ )
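
The forward pass above is the zero-inflated readout: the `pi` branch scores empty vs. non-empty per block, the `lambda` branch predicts the block count under the count component, and the density map is their product. A toy numerical sketch of that readout (made-up logits, not repo code):

```python
# Toy sketch of the zero-inflated readout in CLIP_EBC.forward (random logits).
import torch

bin_centers = torch.tensor([0.0, 1.0, 2.5, 5.0]).view(1, -1, 1, 1)  # N = 4 bins
pi_logit_map = torch.randn(1, 2, 3, 3)      # zero vs. non-zero logits per block
lambda_logit_map = torch.randn(1, 3, 3, 3)  # logits over the N - 1 non-zero bins

# Expected count from the count branch: softmax-weighted average of the non-zero bin centers
lambda_map = (lambda_logit_map.softmax(dim=1) * bin_centers[:, 1:]).sum(dim=1, keepdim=True)
# P(non-zero) * expected count gives the density value per block
den_map = pi_logit_map.softmax(dim=1)[:, 1:] * lambda_map
print(den_map.shape)  # torch.Size([1, 1, 3, 3]); den_map.sum() is the image-level count
```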
models/clip_ebc/resnet.py ADDED
@@ -0,0 +1,236 @@
1
+ from torch import nn, Tensor
2
+ import open_clip
3
+ from peft import get_peft_model, LoraConfig
4
+
5
+ from ..utils import ConvRefine, ConvUpsample, ConvAdapter
6
+ from ..utils import _get_norm_layer, _get_activation
7
+
8
+
9
+ resnet_names_and_weights = {
10
+ "RN50": ["openai", "yfcc15m", "cc12m"],
11
+ "RN101": ["openai", "yfcc15m", "cc12m"],
12
+ "RN50x4": ["openai", "yfcc15m", "cc12m"],
13
+ "RN50x16": ["openai", "yfcc15m", "cc12m"],
14
+ "RN50x64": ["openai", "yfcc15m", "cc12m"],
15
+ }
16
+
17
+ refiner_channels = {
18
+ "RN50": 2048,
19
+ "RN101": 2048,
20
+ "RN50x4": 2560,
21
+ "RN50x16": 3072,
22
+ "RN50x64": 4096,
23
+ }
24
+
25
+ refiner_groups = {
26
+ "RN50": refiner_channels["RN50"] // 512, # 4
27
+ "RN101": refiner_channels["RN101"] // 512, # 4
28
+ "RN50x4": refiner_channels["RN50x4"] // 512, # 5
29
+ "RN50x16": refiner_channels["RN50x16"] // 512, # 6
30
+ "RN50x64": refiner_channels["RN50x64"] // 512, # 8
31
+ }
32
+
33
+
34
+ class ResNet(nn.Module):
35
+ def __init__(
36
+ self,
37
+ model_name: str,
38
+ weight_name: str,
39
+ block_size: int = 16,
40
+ adapter: bool = False,
41
+ adapter_reduction: int = 4,
42
+ norm: str = "none",
43
+ act: str = "none"
44
+ ) -> None:
45
+ super(ResNet, self).__init__()
46
+ assert model_name in resnet_names_and_weights, f"Model name should be one of {list(resnet_names_and_weights.keys())}, but got {model_name}."
47
+ assert weight_name in resnet_names_and_weights[model_name], f"Pretrained should be one of {resnet_names_and_weights[model_name]}, but got {weight_name}."
48
+ assert block_size in [32, 16, 8], f"block_size should be one of [32, 16, 8], got {block_size}"
49
+ self.model_name, self.weight_name = model_name, weight_name
50
+ self.block_size = block_size
51
+
52
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
53
+
54
+ self.adapter = adapter
55
+ if adapter:
56
+ for param in model.parameters():
57
+ param.requires_grad = False
58
+
59
+ # Stem
60
+ self.conv1 = model.conv1
61
+ self.bn1 = model.bn1
62
+ self.act1 = model.act1
63
+ self.conv2 = model.conv2
64
+ self.bn2 = model.bn2
65
+ self.act2 = model.act2
66
+ self.conv3 = model.conv3
67
+ self.bn3 = model.bn3
68
+ self.act3 = model.act3
69
+ self.avgpool = model.avgpool
70
+ # Stem: reduction = 4
71
+
72
+ # Layers
73
+ for idx in range(1, 5):
74
+ setattr(self, f"layer{idx}", getattr(model, f"layer{idx}"))
75
+ if adapter:
76
+ setattr(self, f"adapter{idx}", ConvAdapter(
77
+ in_channels=getattr(model, f"layer{idx}")[-1].conv3.out_channels,
78
+ bottleneck_channels=getattr(model, f"layer{idx}")[-1].conv3.out_channels // adapter_reduction,
79
+ ) if idx < 4 else nn.Identity()) # No adapter for the last layer
80
+
81
+ self.in_features = model.attnpool.c_proj.weight.shape[1]
82
+ self.out_features = model.attnpool.c_proj.weight.shape[0]
83
+
84
+ if norm == "bn":
85
+ norm_layer = nn.BatchNorm2d
86
+ elif norm == "ln":
87
+ norm_layer = nn.LayerNorm
88
+ else:
89
+ norm_layer = _get_norm_layer(model)
90
+
91
+ if act == "relu":
92
+ activation = nn.ReLU(inplace=True)
93
+ elif act == "gelu":
94
+ activation = nn.GELU()
95
+ else:
96
+ activation = _get_activation(model)
97
+
98
+ if block_size == 32:
99
+ self.refiner = ConvRefine(
100
+ in_channels=self.in_features,
101
+ out_channels=self.in_features,
102
+ norm_layer=norm_layer,
103
+ activation=activation,
104
+ groups=refiner_groups[self.model_name],
105
+ )
106
+ elif block_size == 16:
107
+ self.refiner = ConvUpsample(
108
+ in_channels=self.in_features,
109
+ out_channels=self.in_features,
110
+ norm_layer=norm_layer,
111
+ activation=activation,
112
+ groups=refiner_groups[self.model_name],
113
+ )
114
+ else: # block_size == 8
115
+ self.refiner = nn.Sequential(
116
+ ConvUpsample(
117
+ in_channels=self.in_features,
118
+ out_channels=self.in_features,
119
+ norm_layer=norm_layer,
120
+ activation=activation,
121
+ groups=refiner_groups[self.model_name],
122
+ ),
123
+ ConvUpsample(
124
+ in_channels=self.in_features,
125
+ out_channels=self.in_features,
126
+ norm_layer=norm_layer,
127
+ activation=activation,
128
+ groups=refiner_groups[self.model_name],
129
+ ),
130
+ )
131
+
132
+ def train(self, mode: bool = True):
133
+ if self.adapter and mode:
134
+ # training:
135
+ self.conv1.eval()
136
+ self.bn1.eval()
137
+ self.act1.eval()
138
+ self.conv2.eval()
139
+ self.bn2.eval()
140
+ self.act2.eval()
141
+ self.conv3.eval()
142
+ self.bn3.eval()
143
+ self.act3.eval()
144
+ self.avgpool.eval()
145
+
146
+ for idx in range(1, 5):
147
+ getattr(self, f"layer{idx}").eval()
148
+ getattr(self, f"adapter{idx}").train()
149
+
150
+ self.refiner.train()
151
+
152
+ else:
153
+ # evaluation:
154
+ for module in self.children():
155
+ module.train(mode)
156
+
157
+ def stem(self, x: Tensor) -> Tensor:
158
+ x = self.act1(self.bn1(self.conv1(x)))
159
+ x = self.act2(self.bn2(self.conv2(x)))
160
+ x = self.act3(self.bn3(self.conv3(x)))
161
+ x = self.avgpool(x)
162
+ return x
163
+
164
+ def forward(self, x: Tensor) -> Tensor:
165
+ x = self.stem(x)
166
+
167
+ x = self.layer1(x)
168
+ if self.adapter:
169
+ x = self.adapter1(x)
170
+
171
+ x = self.layer2(x)
172
+ if self.adapter:
173
+ x = self.adapter2(x)
174
+
175
+ x = self.layer3(x)
176
+ if self.adapter:
177
+ x = self.adapter3(x)
178
+
179
+ x = self.layer4(x)
180
+ if self.adapter:
181
+ x = self.adapter4(x)
182
+
183
+ x = self.refiner(x)
184
+ return x
185
+
186
+
187
+ def _resnet(
188
+ model_name: str,
189
+ weight_name: str,
190
+ block_size: int = 16,
191
+ adapter: bool = False,
192
+ adapter_reduction: int = 4,
193
+ lora: bool = False,
194
+ lora_rank: int = 16,
195
+ lora_alpha: float = 32.0,
196
+ lora_dropout: float = 0.1,
197
+ norm: str = "none",
198
+ act: str = "none"
199
+ ) -> ResNet:
200
+ assert not (lora and adapter), "Lora and adapter cannot be used together."
201
+ model = ResNet(
202
+ model_name=model_name,
203
+ weight_name=weight_name,
204
+ block_size=block_size,
205
+ adapter=adapter,
206
+ adapter_reduction=adapter_reduction,
207
+ norm=norm,
208
+ act=act
209
+ )
210
+
211
+ if lora:
212
+ target_modules = []
213
+ for name, module in model.named_modules():
214
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
215
+ target_modules.append(name)
216
+
217
+ lora_config = LoraConfig(
218
+ r=lora_rank,
219
+ lora_alpha=lora_alpha,
220
+ lora_dropout=lora_dropout,
221
+ bias="none",
222
+ target_modules=target_modules,
223
+ )
224
+ model = get_peft_model(model, lora_config)
225
+
226
+ # Unfreeze BN layers
227
+ for name, module in model.named_modules():
228
+ if isinstance(module, nn.BatchNorm2d) and "refiner" not in name:
229
+ module.requires_grad_(True)
230
+
231
+ # Unfreeze refiner
232
+ for name, module in model.named_modules():
233
+ if "refiner" in name:
234
+ module.requires_grad_(True)
235
+
236
+ return model
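
Usage mirrors the ConvNeXt backbone: the refiner upsamples the stride-32 CLIP-ResNet feature map to the requested block size. A sketch assuming the OpenAI `RN50` weights are available.

```python
# Sketch only: RN50 backbone with block_size 8 -> two ConvUpsample stages.
import torch
from models.clip_ebc.resnet import _resnet

backbone = _resnet("RN50", "openai", block_size=8)
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)  # (1, backbone.in_features, 224 // 8, 224 // 8)
```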
models/clip_ebc/utils.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch import Tensor, nn
3
+ import torch.nn.functional as F
4
+ import open_clip
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from typing import Union, Tuple, List
8
+
9
+
10
+ num_to_word = {
11
+ "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine",
12
+ "10": "ten", "11": "eleven", "12": "twelve", "13": "thirteen", "14": "fourteen", "15": "fifteen", "16": "sixteen", "17": "seventeen", "18": "eighteen", "19": "nineteen",
13
+ "20": "twenty", "21": "twenty-one", "22": "twenty-two", "23": "twenty-three", "24": "twenty-four", "25": "twenty-five", "26": "twenty-six", "27": "twenty-seven", "28": "twenty-eight", "29": "twenty-nine",
14
+ "30": "thirty", "31": "thirty-one", "32": "thirty-two", "33": "thirty-three", "34": "thirty-four", "35": "thirty-five", "36": "thirty-six", "37": "thirty-seven", "38": "thirty-eight", "39": "thirty-nine",
15
+ "40": "forty", "41": "forty-one", "42": "forty-two", "43": "forty-three", "44": "forty-four", "45": "forty-five", "46": "forty-six", "47": "forty-seven", "48": "forty-eight", "49": "forty-nine",
16
+ "50": "fifty", "51": "fifty-one", "52": "fifty-two", "53": "fifty-three", "54": "fifty-four", "55": "fifty-five", "56": "fifty-six", "57": "fifty-seven", "58": "fifty-eight", "59": "fifty-nine",
17
+ "60": "sixty", "61": "sixty-one", "62": "sixty-two", "63": "sixty-three", "64": "sixty-four", "65": "sixty-five", "66": "sixty-six", "67": "sixty-seven", "68": "sixty-eight", "69": "sixty-nine",
18
+ "70": "seventy", "71": "seventy-one", "72": "seventy-two", "73": "seventy-three", "74": "seventy-four", "75": "seventy-five", "76": "seventy-six", "77": "seventy-seven", "78": "seventy-eight", "79": "seventy-nine",
19
+ "80": "eighty", "81": "eighty-one", "82": "eighty-two", "83": "eighty-three", "84": "eighty-four", "85": "eighty-five", "86": "eighty-six", "87": "eighty-seven", "88": "eighty-eight", "89": "eighty-nine",
20
+ "90": "ninety", "91": "ninety-one", "92": "ninety-two", "93": "ninety-three", "94": "ninety-four", "95": "ninety-five", "96": "ninety-six", "97": "ninety-seven", "98": "ninety-eight", "99": "ninety-nine",
21
+ "100": "one hundred"
22
+ }
23
+
24
+ prefixes = [
25
+ "",
26
+ "A photo of", "A block of", "An image of", "A picture of",
27
+ "There are",
28
+ "The image contains", "The photo contains", "The picture contains",
29
+ "The image shows", "The photo shows", "The picture shows",
30
+ ]
31
+ arabic_numeral = [True, False]
32
+ compares = [
33
+ "more than", "greater than", "higher than", "larger than", "bigger than", "greater than or equal to",
34
+ "at least", "no less than", "not less than", "not fewer than", "not lower than", "not smaller than", "not less than or equal to",
35
+ "over", "above", "beyond", "exceeding", "surpassing",
36
+ ]
37
+ suffixes = [
38
+ "people", "persons", "individuals", "humans", "faces", "heads", "figures", "",
39
+ ]
40
+
41
+
42
+ def num2word(num: Union[int, str]) -> str:
43
+ """
44
+ Convert the input number to the corresponding English word. For example, 1 -> "one", 2 -> "two", etc.
45
+ """
46
+ num = str(int(num))
47
+ return num_to_word.get(num, num)
48
+
49
+
50
+ def format_count(
51
+ bins: List[Union[float, Tuple[float, float]]],
52
+ ) -> List[List[str]]:
53
+ text_prompts = []
54
+ for prefix in prefixes:
55
+ for numeral in arabic_numeral:
56
+ for compare in compares:
57
+ for suffix in suffixes:
58
+ prompts = []
59
+ for bin in bins:
60
+ if isinstance(bin, (int, float)): # count is a single number
61
+ count = int(bin)
62
+ if count == 0 or count == 1:
63
+ count = num2word(count) if not numeral else count
64
+ prefix_ = "There is" if prefix == "There are" else prefix
65
+ suffix_ = "person" if suffix == "people" else suffix[:-1]
66
+ prompt = f"{prefix_} {count} {suffix_}"
67
+ else: # count > 1
68
+ count = num2word(count) if not numeral else count
69
+ prompt = f"{prefix} {count} {suffix}"
70
+
71
+ elif bin[1] == float("inf"): # count is (lower_bound, inf)
72
+ count = int(bin[0])
73
+ count = num2word(count) if not numeral else count
74
+ prompt = f"{prefix} {compare} {count} {suffix}"
75
+
76
+ else: # bin is (lower_bound, upper_bound)
77
+ left, right = int(bin[0]), int(bin[1])
78
+ left, right = num2word(left) if not numeral else left, num2word(right) if not numeral else right
79
+ prompt = f"{prefix} between {left} and {right} {suffix}"
80
+
81
+ # Remove starting and trailing whitespaces
82
+ prompt = prompt.strip() + "."
83
+
84
+ prompts.append(prompt)
85
+
86
+ text_prompts.append(prompts)
87
+
88
+ return text_prompts
89
+
90
+
91
+ def encode_text(
92
+ model_name: str,
93
+ weight_name: str,
94
+ text: List[str]
95
+ ) -> Tensor:
96
+ if torch.cuda.is_available():
97
+ device = torch.device("cuda")
98
+ elif torch.backends.mps.is_available():
99
+ device = torch.device("mps")
100
+ else:
101
+ device = torch.device("cpu")
102
+ text = open_clip.get_tokenizer(model_name)(text).to(device)
103
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).to(device)
104
+ model.eval()
105
+ with torch.no_grad():
106
+ text_feats = model.encode_text(text)
107
+ text_feats = F.normalize(text_feats, p=2, dim=-1).detach().cpu()
108
+ return text_feats
109
+
110
+
111
+ def optimize_text_prompts(
112
+ model_name: str,
113
+ weight_name: str,
114
+ flat_bins: List[Union[float, Tuple[float, float]]],
115
+ batch_size: int = 1024,
116
+ ) -> List[str]:
117
+ text_prompts = format_count(flat_bins)
118
+
119
+ # Find the template that has the smallest average similarity of bin prompts.
120
+ print("Finding the best setup for text prompts...")
121
+ text_prompts_ = [prompt for prompts in text_prompts for prompt in prompts] # flatten the list
122
+ text_feats = []
123
+ for i in tqdm(range(0, len(text_prompts_), batch_size)):
124
+ text_feats.append(encode_text(model_name, weight_name, text_prompts_[i: min(i + batch_size, len(text_prompts_))]))
125
+ text_feats = torch.cat(text_feats, dim=0)
126
+
127
+ sims = []
128
+ for idx, prompts in enumerate(text_prompts):
129
+ text_feats_ = text_feats[idx * len(prompts): (idx + 1) * len(prompts)]
130
+ sim = torch.mm(text_feats_, text_feats_.T)
131
+ sim = sim[~torch.eye(sim.shape[0], dtype=bool)].mean().item()
132
+ sims.append(sim)
133
+
134
+ optimal_prompts = text_prompts[np.argmin(sims)]
135
+ sim = sims[np.argmin(sims)]
136
+ print(f"Found the best text prompts: {optimal_prompts} (similarity: {sim:.2f})")
137
+ return optimal_prompts
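
`format_count` enumerates every prefix × numeral × comparator × suffix combination, and `optimize_text_prompts` keeps the single template whose per-bin prompts are least similar to each other in CLIP text space. A small illustration with example bins (not taken from the repo's configs):

```python
# Illustration only: candidate prompt templates for the bins [0, 1, 2, 3, (4, inf)].
from models.clip_ebc.utils import format_count

templates = format_count([0, 1, 2, 3, (4, float("inf"))])
# One candidate template reads roughly:
#   ["There is zero person.", "There is one person.", "There are two people.",
#    "There are three people.", "There are more than four people."]
# optimize_text_prompts() then encodes every candidate and keeps the template with
# the lowest mean pairwise cosine similarity between its bin prompts.
print(len(templates))  # len(prefixes) * 2 * len(compares) * len(suffixes) candidates
```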
models/clip_ebc/vit.py ADDED
@@ -0,0 +1,372 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ import math
4
+ from einops import rearrange
5
+ import open_clip
6
+ from peft import get_peft_model, LoraConfig
7
+ from typing import Optional, Tuple
8
+
9
+ from ..utils import interpolate_pos_embed, ViTAdapter
10
+ # from ..utils import TransformerRefine, TransformerDownsample, TransformerUpsample
11
+ from ..utils import ConvRefine, ConvDownsample, ConvUpsample
12
+ from ..utils import _get_norm_layer, _get_activation
13
+
14
+
15
+ vit_names_and_weights = {
16
+ "ViT-B-32": [
17
+ "openai",
18
+ "laion400m_e31", "laion400m_e32", "laion2b_e16", "laion2b_s34b_b79k",
19
+ "datacomp_xl_s13b_b90k", "datacomp_m_s128m_b4k", "datacomp_s_s13m_b4k",
20
+ "commonpool_m_clip_s128m_b4k", "commonpool_m_laion_s128m_b4k", "commonpool_m_image_s128m_b4k", "commonpool_m_text_s128m_b4k", "commonpool_m_basic_s128m_b4k", "commonpool_m_s128m_b4k",
21
+ "commonpool_s_clip_s13m_b4k", "commonpool_s_laion_s13m_b4k", "commonpool_s_image_s13m_b4k", "commonpool_s_text_s13m_b4k", "commonpool_s_basic_s13m_b4k", "commonpool_s_s13m_b4k",
22
+ ],
23
+ "ViT-B-32-256": ["datacomp_s34b_b86k"],
24
+ "ViT-B-16": [
25
+ "openai",
26
+ "laion400m_e31", "laion400m_e32", "laion2b_s34b_b88k",
27
+ "datacomp_xl_s13b_b90k", "datacomp_l_s1b_b8k",
28
+ "commonpool_l_clip_s1b_b8k", "commonpool_l_laion_s1b_b8k", "commonpool_l_image_s1b_b8k", "commonpool_l_text_s1b_b8k", "commonpool_l_basic_s1b_b8k", "commonpool_l_s1b_b8k",
29
+ "dfn2b"
30
+ ],
31
+ "ViT-L-14": [
32
+ "openai",
33
+ "laion400m_e31", "laion400m_e32", "laion2b_s32b_b82k",
34
+ "datacomp_xl_s13b_b90k",
35
+ "commonpool_xl_clip_s13b_b90k", "commonpool_xl_laion_s13b_b90k", "commonpool_xl_s13b_b90k"
36
+ ],
37
+ "ViT-L-14-336": ["openai"],
38
+ "ViT-H-14": ["laion2b_s32b_b79k"],
39
+ "ViT-g-14": ["laion2b_s12b_b42k", "laion2b_s34b_b88k"],
40
+ "ViT-bigG-14": ["laion2b_s39b_b160k"],
41
+ }
42
+
43
+
44
+ refiner_channels = {
45
+ "ViT-B-32": 768,
46
+ "ViT-B-32-256": 768,
47
+ "ViT-B-16": 768,
48
+ "ViT-L-14": 1024,
49
+ "ViT-L-14-336": 1024,
50
+ "ViT-H-14": 1280,
51
+ "ViT-g-14": 1408,
52
+ "ViT-bigG-14": 1664,
53
+ }
54
+
55
+ refiner_groups = {
56
+ "ViT-B-32": 1,
57
+ "ViT-B-32-256": 1,
58
+ "ViT-B-16": 1,
59
+ "ViT-L-14": 1,
60
+ "ViT-L-14-336": 1,
61
+ "ViT-H-14": 1,
62
+ "ViT-g-14": refiner_channels["ViT-g-14"] // 704, # 2
63
+ "ViT-bigG-14": refiner_channels["ViT-bigG-14"] // 416, # 4
64
+ }
65
+
66
+
67
+
68
+ class ViT(nn.Module):
69
+ def __init__(
70
+ self,
71
+ model_name: str,
72
+ weight_name: str,
73
+ block_size: int = 16,
74
+ num_vpt: int = 32,
75
+ vpt_drop: float = 0.0,
76
+ adapter: bool = False,
77
+ adapter_reduction: int = 4,
78
+ input_size: Optional[Tuple[int, int]] = None,
79
+ norm: str = "none",
80
+ act: str = "none"
81
+ ) -> None:
82
+ super(ViT, self).__init__()
83
+ assert model_name in vit_names_and_weights, f"Model name should be one of {list(vit_names_and_weights.keys())}, but got {model_name}."
84
+ assert weight_name in vit_names_and_weights[model_name], f"Pretrained should be one of {vit_names_and_weights[model_name]}, but got {weight_name}."
85
+ if adapter:
86
+ assert num_vpt is None or num_vpt == 0, "num_vpt should be None or 0 when using adapter."
87
+ assert vpt_drop is None or vpt_drop == 0.0, "vpt_drop should be None or 0.0 when using adapter."
88
+ else:
89
+ assert num_vpt > 0, f"Number of VPT tokens should be greater than 0, but got {num_vpt}."
90
+ assert 0.0 <= vpt_drop < 1.0, f"VPT dropout should be in [0.0, 1.0), but got {vpt_drop}."
91
+
92
+ self.model_name, self.weight_name = model_name, weight_name
93
+ self.block_size = block_size
94
+ self.num_vpt = num_vpt
95
+ self.vpt_drop = vpt_drop
96
+ self.adapter = adapter
97
+
98
+ model = open_clip.create_model_from_pretrained(model_name, weight_name, return_transform=False).visual
99
+
100
+ # Always freeze the parameters of the model
101
+ for param in model.parameters():
102
+ param.requires_grad = False
103
+
104
+ # Setup the model
105
+ self.input_size = input_size if input_size is not None else model.image_size
106
+ self.pretrain_size = model.image_size
107
+ self.patch_size = model.patch_size
108
+ self.class_embedding = model.class_embedding
109
+ self.positional_embedding = model.positional_embedding
110
+ self.embed_dim = model.class_embedding.shape[-1]
111
+
112
+ self.conv1 = model.conv1
113
+ self.ln_pre = model.ln_pre
114
+ self.resblocks = model.transformer.resblocks
115
+ self.num_layers = len(self.resblocks)
116
+ self.ln_post = model.ln_post
117
+
118
+ # Setup VPT tokens
119
+ val = math.sqrt(6. / float(3 * self.patch_size[0] + self.embed_dim))
120
+ for idx in range(self.num_layers):
121
+ if self.adapter:
122
+ setattr(self, f"adapter{idx}", ViTAdapter(
123
+ in_channels=self.embed_dim,
124
+ bottleneck_channels=self.embed_dim // adapter_reduction,
125
+ ))
126
+ else:
127
+ setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.embed_dim)))
128
+ nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
129
+ setattr(self, f"vpt_drop_{idx}", nn.Dropout(self.vpt_drop))
130
+
131
+ # Adjust the positional embedding to match the new input size
132
+ self._adjust_pos_embed()
133
+
134
+ in_features, out_features = model.proj.shape
135
+ self.in_features = in_features
136
+ self.out_features = out_features
137
+
138
+ patch_size = self.patch_size[0]
139
+ if patch_size in [16, 32]:
140
+ assert block_size in [8, 16, 32], f"Patch size is 32, but got block size {block_size}."
141
+ else: # patch_size == 14
142
+ assert block_size in [7, 14, 28], f"Patch size is 14, but got block size {block_size}."
143
+
144
+ if norm == "bn":
145
+ norm_layer = nn.BatchNorm2d
146
+ elif norm == "ln":
147
+ norm_layer = nn.LayerNorm
148
+ else:
149
+ norm_layer = _get_norm_layer(model)
150
+
151
+ if act == "relu":
152
+ activation = nn.ReLU(inplace=True)
153
+ elif act == "gelu":
154
+ activation = nn.GELU()
155
+ else:
156
+ activation = _get_activation(model)
157
+
158
+ if block_size == patch_size:
159
+ self.refiner = ConvRefine(
160
+ in_channels=self.in_features,
161
+ out_channels=self.in_features,
162
+ norm_layer=norm_layer,
163
+ activation=activation,
164
+ groups=refiner_groups[self.model_name],
165
+ )
166
+
167
+ elif block_size < patch_size: # upsample
168
+ if block_size == 8 and patch_size == 32:
169
+ self.refiner = nn.Sequential(
170
+ ConvUpsample(
171
+ in_channels=self.in_features,
172
+ out_channels=self.in_features,
173
+ norm_layer=norm_layer,
174
+ activation=activation,
175
+ groups=refiner_groups[self.model_name],
176
+ ),
177
+ ConvUpsample(
178
+ in_channels=self.in_features,
179
+ out_channels=self.in_features,
180
+ norm_layer=norm_layer,
181
+ activation=activation,
182
+ groups=refiner_groups[self.model_name],
183
+ ),
184
+ )
185
+ else:
186
+ self.refiner = ConvUpsample(
187
+ in_channels=self.in_features,
188
+ out_channels=self.in_features,
189
+ norm_layer=norm_layer,
190
+ activation=activation,
191
+ groups=refiner_groups[self.model_name],
192
+ )
193
+
194
+ else: # downsample
195
+ assert block_size // patch_size == 2, f"Block size {block_size} should be 2 times the patch size {patch_size}."
196
+ self.refiner = ConvDownsample(
197
+ in_channels=self.in_features,
198
+ out_channels=self.in_features,
199
+ norm_layer=norm_layer,
200
+ activation=activation,
201
+ groups=refiner_groups[self.model_name],
202
+ )
203
+
204
+ def _adjust_pos_embed(self) -> Tensor:
205
+ """
206
+ Adjust the positional embedding to match the spatial resolution of the feature map.
207
+
208
+ Args:
209
+ orig_h, orig_w: The original spatial resolution of the image.
210
+ new_h, new_w: The new spatial resolution of the image.
211
+ """
212
+ self.positional_embedding = nn.Parameter(self._interpolate_pos_embed(self.pretrain_size[0], self.pretrain_size[1], self.input_size[0], self.input_size[1]), requires_grad=False)
213
+
214
+ def _interpolate_pos_embed(self, orig_h: int, orig_w: int, new_h: int, new_w: int) -> Tensor:
215
+ """
216
+ Interpolate the positional embedding to match the spatial resolution of the feature map.
217
+
218
+ Args:
219
+ orig_h, orig_w: The original spatial resolution of the image.
220
+ new_h, new_w: The new spatial resolution of the image.
221
+ """
222
+ if (orig_h, orig_w) == (new_h, new_w):
223
+ return self.positional_embedding
224
+
225
+ orig_h_patches, orig_w_patches = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
226
+ new_h_patches, new_w_patches = new_h // self.patch_size[0], new_w // self.patch_size[1]
227
+ class_pos_embed, patch_pos_embed = self.positional_embedding[:1, :], self.positional_embedding[1:, :]
228
+ patch_pos_embed = rearrange(patch_pos_embed, "(h w) d -> d h w", h=orig_h_patches, w=orig_w_patches)
229
+ patch_pos_embed = interpolate_pos_embed(patch_pos_embed, size=(new_h_patches, new_w_patches))
230
+ patch_pos_embed = rearrange(patch_pos_embed, "d h w -> (h w) d")
231
+ pos_embed = torch.cat((class_pos_embed, patch_pos_embed), dim=0)
232
+ return pos_embed
233
+
234
+ def train(self, mode: bool = True):
235
+ if mode:
236
+ # training:
237
+ self.conv1.eval()
238
+ self.ln_pre.eval()
239
+ self.resblocks.eval()
240
+ self.ln_post.eval()
241
+
242
+ for idx in range(self.num_layers):
243
+ if self.adapter:
+ getattr(self, f"adapter{idx}").train()
+ else:
+ getattr(self, f"vpt_drop_{idx}").train()
244
+
245
+ self.refiner.train()
246
+
247
+ else:
248
+ # evaluation:
249
+ for module in self.children():
250
+ module.train(mode)
251
+
252
+ def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
253
+ vpt = getattr(self, f"vpt_{layer}").unsqueeze(0).expand(batch_size, -1, -1).to(device) # (batch_size, num_vpt, embed_dim)
254
+ vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
255
+
256
+ return vpt
257
+
258
+ def _forward_patch_embed(self, x: Tensor) -> Tensor:
259
+ # This step performs 1) embed x into patches; 2) append the class token; 3) add positional embeddings.
260
+ assert len(x.shape) == 4, f"Expected input to have shape (batch_size, 3, height, width), but got {x.shape}"
261
+ batch_size, _, height, width = x.shape
262
+
263
+ # Step 1: Embed x into patches
264
+ x = self.conv1(x)
265
+
266
+ # Step 2: Append the class token
267
+ class_embedding = self.class_embedding.expand(batch_size, 1, -1)
268
+ x = rearrange(x, "b d h w -> b (h w) d")
269
+ x = torch.cat([class_embedding, x], dim=1)
270
+
271
+ # Step 3: Add positional embeddings
272
+ pos_embed = self._interpolate_pos_embed(orig_h=self.input_size[0], orig_w=self.input_size[1], new_h=height, new_w=width).expand(batch_size, -1, -1)
273
+ x = x + pos_embed
274
+
275
+ x = self.ln_pre(x)
276
+ return x
277
+
278
+ def _forward_vpt(self, x: Tensor, idx: int) -> Tensor:
279
+ batch_size = x.shape[0]
280
+ device = x.device
281
+
282
+ # Assemble
283
+ vpt = self._prepare_vpt(idx, batch_size, device)
284
+ x = torch.cat([
285
+ x[:, :1, :], # class token
286
+ vpt,
287
+ x[:, 1:, :] # patches
288
+ ], dim=1)
289
+
290
+ # Forward
291
+ x = self.resblocks[idx](x)
292
+
293
+ # Disassemble
294
+ x = torch.cat([
295
+ x[:, :1, :], # class token
296
+ x[:, 1 + self.num_vpt:, :] # patches
297
+ ], dim=1)
298
+
299
+ return x
300
+
301
+ def _forward_adapter(self, x: Tensor, idx: int) -> Tensor:
302
+ return getattr(self, f"adapter{idx}")(x)
303
+
304
+ def forward_encoder(self, x: Tensor) -> Tensor:
305
+ x = self._forward_patch_embed(x)
306
+ for idx in range(self.num_layers):
307
+ x = self._forward_adapter(x, idx) if self.adapter else self._forward_vpt(x, idx)
308
+ x = self.ln_post(x)
309
+ return x
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ orig_h, orig_w = x.shape[-2:]
313
+ num_patches_h, num_patches_w = orig_h // self.patch_size[0], orig_w // self.patch_size[1]
314
+ x = self.forward_encoder(x)
315
+ x = x[:, 1:, :] # remove the class token
316
+ x = rearrange(x, "b (h w) d -> b d h w", h=num_patches_h, w=num_patches_w)
317
+
318
+ x = self.refiner(x)
319
+ return x
320
+
321
+
322
+ def _vit(
323
+ model_name: str,
324
+ weight_name: str,
325
+ block_size: int = 16,
326
+ num_vpt: int = 32,
327
+ vpt_drop: float = 0.1,
328
+ adapter: bool = False,
329
+ adapter_reduction: int = 4,
330
+ lora: bool = False,
331
+ lora_rank: int = 16,
332
+ lora_alpha: float = 32.0,
333
+ lora_dropout: float = 0.1,
334
+ input_size: Optional[Tuple[int, int]] = None,
335
+ norm: str = "none",
336
+ act: str = "none"
337
+ ) -> ViT:
338
+ assert not (lora and adapter), "LoRA and adapter cannot be used together."
339
+ model = ViT(
340
+ model_name=model_name,
341
+ weight_name=weight_name,
342
+ block_size=block_size,
343
+ num_vpt=num_vpt,
344
+ vpt_drop=vpt_drop,
345
+ adapter=adapter,
346
+ adapter_reduction=adapter_reduction,
347
+ input_size=input_size,
348
+ norm=norm,
349
+ act=act
350
+ )
351
+
352
+ if lora:
353
+ target_modules = []
354
+ for name, module in model.named_modules():
355
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.MultiheadAttention)) and "refiner" not in name:
356
+ target_modules.append(name)
357
+
358
+ lora_config = LoraConfig(
359
+ r=lora_rank,
360
+ lora_alpha=lora_alpha,
361
+ lora_dropout=lora_dropout,
362
+ bias="none",
363
+ target_modules=target_modules,
364
+ )
365
+ model = get_peft_model(model, lora_config)
366
+
367
+ # Unfreeze refiner
368
+ for name, module in model.named_modules():
369
+ if "refiner" in name:
370
+ module.requires_grad_(True)
371
+
372
+ return model
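
A usage sketch of the VPT-tuned ViT backbone (assuming the OpenAI `ViT-B-16` weights can be fetched): the transformer stays frozen, only the per-layer prompt tokens and the refiner receive gradients, and positional embeddings are interpolated whenever the input resolution differs from the pretraining one.

```python
# Sketch only: frozen CLIP ViT-B/16 with 32 visual prompt tokens per layer.
import torch
from models.clip_ebc.vit import _vit

backbone = _vit("ViT-B-16", "openai", block_size=16, num_vpt=32, vpt_drop=0.1,
                input_size=(224, 224))
backbone.eval()
trainable = [n for n, p in backbone.named_parameters() if p.requires_grad]
print(all("vpt_" in n or "refiner" in n for n in trainable))  # True
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)  # (1, backbone.in_features, 14, 14) for 16x16 patches
```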
models/ebc/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .model import EBC, _ebc
+
+ __all__ = ["EBC", "_ebc"]