haodongli committed on
Commit d82e7f9 · 1 Parent(s): 3892107
Files changed (44)
  1. da2/__init__.py +25 -0
  2. da2/__pycache__/__init__.cpython-312.pyc +0 -0
  3. da2/model/__init__.py +11 -0
  4. da2/model/__pycache__/__init__.cpython-312.pyc +0 -0
  5. da2/model/__pycache__/base.cpython-312.pyc +0 -0
  6. da2/model/__pycache__/sphere.cpython-312.pyc +0 -0
  7. da2/model/__pycache__/spherevit.cpython-312.pyc +0 -0
  8. da2/model/__pycache__/vit_w_esphere.cpython-312.pyc +0 -0
  9. da2/model/base.py +393 -0
  10. da2/model/dinov2/__init__.py +13 -0
  11. da2/model/dinov2/__pycache__/__init__.cpython-312.pyc +0 -0
  12. da2/model/dinov2/__pycache__/attention.cpython-312.pyc +0 -0
  13. da2/model/dinov2/__pycache__/block.cpython-312.pyc +0 -0
  14. da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc +0 -0
  15. da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc +0 -0
  16. da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc +0 -0
  17. da2/model/dinov2/__pycache__/mlp.cpython-312.pyc +0 -0
  18. da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc +0 -0
  19. da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc +0 -0
  20. da2/model/dinov2/attention.py +79 -0
  21. da2/model/dinov2/block.py +280 -0
  22. da2/model/dinov2/dino_head.py +68 -0
  23. da2/model/dinov2/dinovit.py +223 -0
  24. da2/model/dinov2/drop_path.py +37 -0
  25. da2/model/dinov2/layer_scale.py +28 -0
  26. da2/model/dinov2/mlp.py +41 -0
  27. da2/model/dinov2/patch_embed.py +101 -0
  28. da2/model/dinov2/swiglu_ffn.py +63 -0
  29. da2/model/sphere.py +30 -0
  30. da2/model/spherevit.py +69 -0
  31. da2/model/vit_w_esphere.py +224 -0
  32. da2/utils/__init__.py +11 -0
  33. da2/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  34. da2/utils/__pycache__/base.cpython-312.pyc +0 -0
  35. da2/utils/__pycache__/d2pc.cpython-312.pyc +0 -0
  36. da2/utils/__pycache__/io.cpython-312.pyc +0 -0
  37. da2/utils/__pycache__/model.cpython-312.pyc +0 -0
  38. da2/utils/__pycache__/vis.cpython-312.pyc +0 -0
  39. da2/utils/base.py +56 -0
  40. da2/utils/d2pc.py +76 -0
  41. da2/utils/io.py +63 -0
  42. da2/utils/model.py +15 -0
  43. da2/utils/vis.py +44 -0
  44. requirements.txt +19 -1
da2/__init__.py ADDED
@@ -0,0 +1,25 @@
+ from .utils.base import (
+     prepare_to_run
+ )
+ from .utils.model import (
+     load_model
+ )
+ from .utils.io import (
+     load_infer_data
+ )
+ from .utils.vis import (
+     colorize_distance,
+     concatenate_images
+ )
+ from .utils.d2pc import (
+     distance2pointcloud
+ )
+
+ __all__ = [
+     'prepare_to_run',
+     'load_model',
+     'load_infer_data',
+     'colorize_distance',
+     'concatenate_images',
+     'distance2pointcloud'
+ ]
da2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (498 Bytes).
da2/model/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .spherevit import (
+     SphereViT
+ )
+ from .vit_w_esphere import (
+     ViT_w_Esphere
+ )
+
+ __all__ = [
+     'SphereViT',
+     'ViT_w_Esphere',
+ ]
da2/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (298 Bytes).
da2/model/__pycache__/base.cpython-312.pyc ADDED
Binary file (18.4 kB).
da2/model/__pycache__/sphere.cpython-312.pyc ADDED
Binary file (2.39 kB).
da2/model/__pycache__/spherevit.cpython-312.pyc ADDED
Binary file (3.96 kB).
da2/model/__pycache__/vit_w_esphere.cpython-312.pyc ADDED
Binary file (10.5 kB).
da2/model/base.py ADDED
@@ -0,0 +1,393 @@
+ import torch
+ import torch.nn as nn
+ from math import log2, pi
+ from typing import Tuple
+ import torch.nn.functional as F
+ from einops import rearrange
+ from functools import partial
+
+
+ def fourier_dimension_expansion(
+     x: torch.Tensor,
+     dim: int = 512,
+     max_freq: int = 64,
+     use_cos: bool = True,
+     use_log: bool = True,
+ ):
+     device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
+     # input_dim: 2
+     num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
+     # num_bands = 512 // 2 = 256
+     if use_log:
+         scales = 2.0 ** torch.linspace(
+             0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
+         )
+     else:
+         scales = torch.linspace(
+             1.0, max_freq / 2, num_bands, device=device, dtype=dtype
+         )
+     x = x.unsqueeze(-1)
+     scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+     x = x * scales * pi
+     x = torch.cat(
+         (
+             [x.sin(), x.cos()]
+             if use_cos
+             else [
+                 x.sin(),
+             ]
+         ),
+         dim=-1,
+     )
+     x = x.flatten(-2)
+     return x
+
+ def flatten(
+     flat_tensor: torch.Tensor,
+     old: Tuple[int, int],
+     new: Tuple[int, int],
+ ) -> torch.Tensor:
+     if old[0] == new[0] and old[1] == new[1]:
+         return flat_tensor
+     tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
+         0, 3, 1, 2
+     )  # b c h w
+     tensor_interp = F.interpolate(
+         tensor,
+         size=(new[0], new[1]),
+         mode='nearest',
+     )
+     flat_tensor_interp = tensor_interp.view(
+         flat_tensor.shape[0], -1, new[0] * new[1]
+     ).permute(
+         0, 2, 1
+     )  # b (h w) c
+     return flat_tensor_interp.contiguous()
+
+
+ class DimensionAligner(nn.Module):
+     def __init__(self, input_dims: list[int], hidden_dim: int):
+         super().__init__()
+         self.aligners = nn.ModuleList([])
+         self.num_chunks = len(input_dims)
+         self.checkpoint = True
+         for input_dim in input_dims:
+             self.aligners.append(nn.Linear(input_dim, hidden_dim))
+
+     def forward(self, xs: torch.Tensor) -> torch.Tensor:
+         outs = [self.aligners[i](x) for i, x in enumerate(xs)]
+         return outs
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: float | torch.Tensor = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+ def exists(val):
+     return val is not None
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if callable(d) else d
+
+
+ class SwiGLU(nn.Module):
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x, gates = x.chunk(2, dim=-1)
+         return x * F.silu(gates)
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         input_dim: int,
+         expansion: int = 4,
+         dropout: float = 0.0,
+         gated: bool = False,
+         output_dim: int | None = None,
+     ):
+         super().__init__()
+         if gated:
+             expansion = int(expansion * 2 / 3)
+         hidden_dim = int(input_dim * expansion)
+         output_dim = default(output_dim, input_dim)
+         self.norm = nn.LayerNorm(input_dim)
+         self.proj1 = nn.Linear(input_dim, hidden_dim)
+         self.proj2 = nn.Linear(hidden_dim, output_dim)
+         self.act = nn.GELU() if not gated else SwiGLU()
+         self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.norm(x)
+         x = self.proj1(x)
+         x = self.act(x)
+         x = self.proj2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class AttentionBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 4,
+         expansion: int = 4,
+         dropout: float = 0.0,
+         cosine: bool = False,
+         gated: bool = False,
+         layer_scale: float = 1.0,
+         context_dim: int | None = None,
+         detach_query: bool = False,
+         residual_ls: bool = False,
+     ):
+         super().__init__()
+         self.dropout = dropout
+         self.num_heads = num_heads
+         self.hidden_dim = dim
+         context_dim = dim if context_dim is None else context_dim
+         self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+         self.kv = nn.Linear(context_dim, dim * 2, bias=False)
+         self.q = nn.Linear(dim, dim, bias=False)
+         self.norm_attnx = nn.LayerNorm(dim)
+         self.norm_attnctx = nn.LayerNorm(context_dim)
+         self.cosine = cosine
+         self.out = nn.Linear(dim, dim, bias=False)
+         self.ls1_1 = (
+             LayerScale(dim, layer_scale)
+             if layer_scale > 0.0 and not residual_ls
+             else nn.Identity()
+         )
+         self.ls1_2 = (
+             LayerScale(dim, layer_scale)
+             if layer_scale > 0.0 and residual_ls
+             else nn.Identity()
+         )
+         self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+         self.detach_query = detach_query
+
+     def attn(
+         self,
+         x: torch.Tensor,
+         attn_bias: torch.Tensor | None = None,
+         context: torch.Tensor | None = None,
+         pos_embed: torch.Tensor | None = None,
+         pos_embed_context: torch.Tensor | None = None,
+         rope: nn.Module | None = None,
+         rope_pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         if self.detach_query:
+             x = x.detach()
+         x = self.norm_attnx(x)
+         context = self.norm_attnctx(context)
+         k, v = rearrange(
+             self.kv(context), 'b n (kv h d) -> b h n d kv', h=self.num_heads, kv=2
+         ).unbind(dim=-1)
+         q = rearrange(self.q(x), 'b n (h d) -> b h n d', h=self.num_heads)
+
+         if rope is not None:
+             q = rope(q.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+             k = rope(k.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+         else:
+             if pos_embed is not None:
+                 pos_embed = rearrange(
+                     pos_embed, 'b n (h d) -> b h n d', h=self.num_heads
+                 )
+                 q = q + pos_embed
+             if pos_embed_context is not None:
+                 pos_embed_context = rearrange(
+                     pos_embed_context, 'b n (h d) -> b h n d', h=self.num_heads
+                 )
+                 k = k + pos_embed_context
+
+         if self.cosine:
+             q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+
+         x = F.scaled_dot_product_attention(
+             q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+         )
+         x = rearrange(x, 'b h n d -> b n (h d)')
+         x = self.out(x)
+         return x
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: torch.Tensor | None = None,
+         pos_embed: torch.Tensor | None = None,
+         pos_embed_context: torch.Tensor | None = None,
+         attn_bias: torch.Tensor | None = None,
+         rope: nn.Module | None = None,
+         rope_pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         context = x if context is None else context
+         x = self.ls1_1(
+             self.attn(
+                 x,
+                 rope=rope,
+                 rope_pos=rope_pos,
+                 attn_bias=attn_bias,
+                 context=context,
+                 pos_embed=pos_embed,
+                 pos_embed_context=pos_embed_context,
+             )
+         ) + self.ls1_2(x)
+         x = self.ls2(self.mlp(x)) + x
+         return x
+
+
+ class AttentionSeq(nn.Module):
+     def __init__(
+         self,
+         num_blocks: int,
+         dim: int,
+         num_heads: int = 4,
+         expansion: int = 4,
+         dropout: float = 0.0,
+         cosine: bool = False,
+         gated: bool = False,
+         layer_scale: float = 1.0,
+         context_dim: int | None = None,
+         detach_query: bool = False,
+         residual_ls: bool = False,
+     ):
+         super().__init__()
+         self.layers = nn.ModuleList(
+             [
+                 AttentionBlock(
+                     dim=dim,
+                     num_heads=num_heads,
+                     expansion=expansion,
+                     dropout=dropout,
+                     cosine=cosine,
+                     gated=gated,
+                     layer_scale=layer_scale,
+                     context_dim=context_dim,
+                     detach_query=detach_query,
+                     residual_ls=residual_ls,
+                 )
+                 for _ in range(num_blocks)
+             ]
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         context: torch.Tensor | None = None,
+         pos_embed: torch.Tensor | None = None,
+         pos_embed_context: torch.Tensor | None = None,
+         attn_bias: torch.Tensor | None = None,
+         rope: nn.Module | None = None,
+         rope_pos: torch.Tensor | None = None,
+     ) -> torch.Tensor:
+         for layer in self.layers:
+             x = layer(
+                 x,
+                 context=context,
+                 pos_embed=pos_embed,
+                 pos_embed_context=pos_embed_context,
+                 attn_bias=attn_bias,
+                 rope=rope,
+                 rope_pos=rope_pos,
+             )
+         return x
+
+
+ class ResidualConvNet(nn.Module):
+     def __init__(
+         self,
+         dim,
+         kernel_size: int = 3,
+         padding_mode: str = 'zeros',
+         dilation: int = 1,
+         layer_scale: float = 1.0,
+         use_norm: bool = False,
+     ):
+         super().__init__()
+         self.conv1 = nn.Conv2d(
+             dim,
+             dim,
+             kernel_size=kernel_size,
+             padding=dilation * (kernel_size - 1) // 2,
+             dilation=dilation,
+             padding_mode=padding_mode,
+         )
+         self.conv2 = nn.Conv2d(
+             dim,
+             dim,
+             kernel_size=kernel_size,
+             padding=dilation * (kernel_size - 1) // 2,
+             dilation=dilation,
+             padding_mode=padding_mode,
+         )
+         self.activation = nn.LeakyReLU()
+         self.gamma = (
+             nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1))
+             if layer_scale > 0.0
+             else 1.0
+         )
+         self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+         self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+
+     def forward(self, x):
+         out = self.activation(x)
+         out = self.conv1(out)
+         out = self.norm1(out)
+         out = self.activation(out)
+         out = self.conv2(out)
+         out = self.norm2(out)
+         return self.gamma * out + x
+
+
+ class ResidualUpsampler(nn.Module):
+     def __init__(
+         self,
+         hidden_dim,
+         output_dim: int = None,
+         num_layers: int = 2,
+         kernel_size: int = 3,
+         layer_scale: float = 1.0,
+         padding_mode: str = 'zeros',
+         use_norm: bool = False,
+         **kwargs,
+     ):
+         super().__init__()
+         output_dim = output_dim if output_dim is not None else hidden_dim // 2
+         self.convs = nn.ModuleList([])
+         for _ in range(num_layers):
+             self.convs.append(
+                 ResidualConvNet(
+                     hidden_dim,
+                     kernel_size=kernel_size,
+                     layer_scale=layer_scale,
+                     padding_mode=padding_mode,
+                     use_norm=use_norm,
+                 )
+             )
+         self.up = nn.Sequential(
+             nn.Conv2d(
+                 hidden_dim,
+                 output_dim,
+                 kernel_size=1,
+                 padding=0,
+                 padding_mode=padding_mode,
+             ),
+             nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+         )
+
+     def forward(self, x: torch.Tensor):
+         for conv in self.convs:
+             x = conv(x)
+         x = self.up(x)
+         return x
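Editor's note: a minimal shape sketch for fourier_dimension_expansion above (not part of the commit; the import path is assumed from the file layout). With use_cos=False and a 2-channel input, the function builds dim // 2 log-spaced frequency bands per channel, multiplies, applies sin, and flattens, so the last dimension becomes exactly dim:

import torch
from da2.model.base import fourier_dimension_expansion  # assumed import path

angles = torch.rand(2, 1369, 2)           # (batch, tokens, [polar, azimuth])
emb = fourier_dimension_expansion(angles, dim=512, max_freq=32, use_cos=False)
print(emb.shape)                          # torch.Size([2, 1369, 512])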
da2/model/dinov2/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .dinovit import (
+     DINOViT
+ )
+
+ __all__ = [
+     'DINOViT'
+ ]
da2/model/dinov2/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (237 Bytes).
da2/model/dinov2/__pycache__/attention.cpython-312.pyc ADDED
Binary file (4.13 kB).
da2/model/dinov2/__pycache__/block.cpython-312.pyc ADDED
Binary file (13.5 kB).
da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc ADDED
Binary file (9.6 kB).
da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc ADDED
Binary file (1.63 kB).
da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc ADDED
Binary file (1.39 kB).
da2/model/dinov2/__pycache__/mlp.cpython-312.pyc ADDED
Binary file (1.92 kB).
da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc ADDED
Binary file (4.07 kB).
da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc ADDED
Binary file (2.81 kB).
da2/model/dinov2/attention.py ADDED
@@ -0,0 +1,79 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ import logging
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ logger = logging.getLogger("dinov2")
+
+
+ try:
+     from xformers.ops import fmha, memory_efficient_attention, unbind
+
+     XFORMERS_AVAILABLE = True
+ except ImportError:
+     logger.warning("xFormers not available")
+     XFORMERS_AVAILABLE = False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity()
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, N, C = x.shape
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+             .permute(2, 0, 3, 1, 4)
+         )
+         x = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2])
+         x = x.transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class MemEffAttention(Attention):
+     def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+         if not XFORMERS_AVAILABLE or x.device.type == "cpu":
+             assert attn_bias is None, "xFormers is required for nested tensors usage"
+             return super().forward(x)
+
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+         q, k, v = unbind(qkv, 2)
+
+         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+         x = x.reshape([B, N, C])
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
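Editor's note: a small usage sketch for the attention classes above (not part of the commit; import path assumed). On CPU, or when xFormers is missing, MemEffAttention falls back to the plain scaled-dot-product path of Attention:

import torch
from da2.model.dinov2.attention import MemEffAttention  # assumed import path

attn = MemEffAttention(dim=64, num_heads=8)
tokens = torch.randn(2, 197, 64)          # (batch, tokens, dim)
out = attn(tokens)                        # CPU tensor -> Attention.forward fallback
print(out.shape)                          # torch.Size([2, 197, 64])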
da2/model/dinov2/block.py ADDED
@@ -0,0 +1,280 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ import logging
+ from typing import Any, Callable, Dict, List, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from .attention import Attention, MemEffAttention
+ from .drop_path import DropPath
+ from .layer_scale import LayerScale
+ from .mlp import Mlp
+
+ logger = logging.getLogger("dinov2")
+
+ try:
+     from xformers.ops import fmha, index_select_cat, scaled_index_add
+
+     XFORMERS_AVAILABLE = True
+ except ImportError:
+     logger.warning("xFormers not available")
+     XFORMERS_AVAILABLE = False
+
+
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         ffn_bias: bool = True,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         init_values=None,
+         drop_path: float = 0.0,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+         attn_class: Callable[..., nn.Module] = Attention,
+         ffn_layer: Callable[..., nn.Module] = Mlp,
+     ) -> None:
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = attn_class(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             proj_bias=proj_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+         )
+         self.ls1 = (
+             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         )
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = ffn_layer(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+             bias=ffn_bias,
+         )
+         self.ls2 = (
+             LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         )
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.sample_drop_ratio = drop_path
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
+             return self.ls1(self.attn(self.norm1(x)))
+
+         def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
+             return self.ls2(self.mlp(self.norm2(x)))
+
+         if self.training and self.sample_drop_ratio > 0.1:
+             # the overhead is compensated only for a drop path rate larger than 0.1
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+         elif self.training and self.sample_drop_ratio > 0.0:
+             x = x + self.drop_path1(attn_residual_func(x))
+             x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+         else:
+             x = x + attn_residual_func(x)
+             x = x + ffn_residual_func(x)
+         return x
+
+
+ def drop_add_residual_stochastic_depth(
+     x: torch.Tensor,
+     residual_func,  #: Callable[[torch.Tensor], torch.Tensor],
+     sample_drop_ratio: float = 0.0,
+ ) -> torch.Tensor:
+     # 1) extract subset using permutation
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     x_subset = x[brange]
+
+     # 2) apply residual_func to get residual
+     residual = residual_func(x_subset)
+
+     x_flat = x.flatten(1)
+     residual = residual.flatten(1)
+
+     residual_scale_factor = b / sample_subset_size
+
+     # 3) add the residual
+     x_plus_residual = torch.index_add(
+         x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+     )
+     return x_plus_residual.view_as(x)
+
+
+ def get_branges_scales(x, sample_drop_ratio=0.0):
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     residual_scale_factor = b / sample_subset_size
+     return brange, residual_scale_factor
+
+
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+     if scaling_vector is None:
+         x_flat = x.flatten(1)
+         residual = residual.flatten(1)
+         x_plus_residual = torch.index_add(
+             x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+         )
+     else:
+         x_plus_residual = scaled_index_add(
+             x,
+             brange,
+             residual.to(dtype=x.dtype),
+             scaling=scaling_vector,
+             alpha=residual_scale_factor,
+         )
+     return x_plus_residual
+
+
+ attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+ def get_attn_bias_and_cat(x_list, branges=None):
+     """
+     this will perform the index select, cat the tensors, and provide the attn_bias from cache
+     """
+     batch_sizes = (
+         [b.shape[0] for b in branges]
+         if branges is not None
+         else [x.shape[0] for x in x_list]
+     )
+     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+     if all_shapes not in attn_bias_cache.keys():
+         seqlens = []
+         for b, x in zip(batch_sizes, x_list):
+             for _ in range(b):
+                 seqlens.append(x.shape[1])
+         attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+         attn_bias._batch_sizes = batch_sizes
+         attn_bias_cache[all_shapes] = attn_bias
+
+     if branges is not None:
+         cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+             1, -1, x_list[0].shape[-1]
+         )
+     else:
+         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+         cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+     return attn_bias_cache[all_shapes], cat_tensors
+
+
+ def drop_add_residual_stochastic_depth_list(
+     x_list: List[torch.Tensor],
+     residual_func,  #: Callable[[torch.Tensor, Any], torch.Tensor],
+     sample_drop_ratio: float = 0.0,
+     scaling_vector=None,
+ ) -> torch.Tensor:
+     # 1) generate random set of indices for dropping samples in the batch
+     branges_scales = [
+         get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+     ]
+     branges = [s[0] for s in branges_scales]
+     residual_scale_factors = [s[1] for s in branges_scales]
+
+     # 2) get attention bias and index+concat the tensors
+     attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+     # 3) apply residual_func to get residual, and split the result
+     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+     outputs = []
+     for x, brange, residual, residual_scale_factor in zip(
+         x_list, branges, residual_list, residual_scale_factors
+     ):
+         outputs.append(
+             add_residual(
+                 x, brange, residual, residual_scale_factor, scaling_vector
+             ).view_as(x)
+         )
+     return outputs
+
+
+ class NestedTensorBlock(Block):
+     def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+         """
+         x_list contains a list of tensors to nest together and run
+         """
+         assert isinstance(self.attn, MemEffAttention)
+
+         if self.training and self.sample_drop_ratio > 0.0:
+
+             def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                 return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+             def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                 return self.mlp(self.norm2(x))
+
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=(
+                     self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
+                 ),
+             )
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=(
+                     self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
+                 ),
+             )
+             return x_list
+         else:
+
+             def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                 return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+             def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                 return self.ls2(self.mlp(self.norm2(x)))
+
+             attn_bias, x = get_attn_bias_and_cat(x_list)
+             x = x + attn_residual_func(x, attn_bias=attn_bias)
+             x = x + ffn_residual_func(x)
+             return attn_bias.split(x)
+
+     def forward(self, x_or_x_list):
+         if isinstance(x_or_x_list, torch.Tensor):
+             return super().forward(x_or_x_list)
+         elif isinstance(x_or_x_list, list):
+             assert (
+                 XFORMERS_AVAILABLE
+             ), "Please install xFormers for nested tensors usage"
+             return self.forward_nested(x_or_x_list)
+         else:
+             raise AssertionError
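Editor's note: a hedged sketch of Block in eval mode (not part of the commit; import path assumed). With the default init_values=None the LayerScale branches collapse to nn.Identity, so the forward pass is two plain pre-norm residual updates:

import torch
from da2.model.dinov2.block import Block  # assumed import path

blk = Block(dim=64, num_heads=4, mlp_ratio=4.0).eval()
x = torch.randn(1, 50, 64)
with torch.no_grad():
    y = blk(x)                            # x + attn(norm1(x)), then + mlp(norm2(.))
print(y.shape)                            # torch.Size([1, 50, 64])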
da2/model/dinov2/dino_head.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.init import trunc_normal_
+ from torch.nn.utils import weight_norm
+
+
+ class DINOHead(nn.Module):
+     def __init__(
+         self,
+         in_dim,
+         out_dim,
+         use_bn=False,
+         nlayers=3,
+         hidden_dim=2048,
+         bottleneck_dim=256,
+         mlp_bias=True,
+     ):
+         super().__init__()
+         nlayers = max(nlayers, 1)
+         self.mlp = _build_mlp(
+             nlayers,
+             in_dim,
+             bottleneck_dim,
+             hidden_dim=hidden_dim,
+             use_bn=use_bn,
+             bias=mlp_bias,
+         )
+         self.apply(self._init_weights)
+         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+         self.last_layer.weight_g.data.fill_(1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.mlp(x)
+         eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+         x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+         x = self.last_layer(x)
+         return x
+
+
+ def _build_mlp(
+     nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+ ):
+     if nlayers == 1:
+         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+     else:
+         layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+         if use_bn:
+             layers.append(nn.BatchNorm1d(hidden_dim))
+         layers.append(nn.GELU())
+         for _ in range(nlayers - 2):
+             layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+             if use_bn:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.GELU())
+         layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+         return nn.Sequential(*layers)
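Editor's note: a brief shape check for DINOHead (not part of the commit; the dimensions are arbitrary). The MLP projects to the bottleneck, the result is L2-normalized, and the weight-normalized last layer maps to the prototype dimension:

import torch
from da2.model.dinov2.dino_head import DINOHead  # assumed import path

head = DINOHead(in_dim=384, out_dim=4096)         # nlayers=3, bottleneck_dim=256 by default
logits = head(torch.randn(8, 384))
print(logits.shape)                               # torch.Size([8, 4096])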
da2/model/dinov2/dinovit.py ADDED
@@ -0,0 +1,223 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ import logging
+
+ import math
+ import torch
+ import torch.nn as nn
+ import contextlib
+ from functools import partial
+ from typing import Sequence
+ from .block import (
+     Block
+ )
+ from .attention import (
+     MemEffAttention
+ )
+ from .mlp import (
+     Mlp
+ )
+ from .patch_embed import (
+     PatchEmbed
+ )
+ from .swiglu_ffn import (
+     SwiGLUFFNFused
+ )
+
+ logger = logging.getLogger("dinov2")
+
+
+ try:
+     from xformers.ops import fmha, memory_efficient_attention, unbind
+
+     XFORMERS_AVAILABLE = True
+ except ImportError:
+     logger.warning("xFormers not available")
+     XFORMERS_AVAILABLE = False
+
+
+ class DINOViT(nn.Module):
+     def __init__(
+         self,
+         img_size=518,
+         patch_size=14,
+         in_chans=3,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         qkv_bias=True,
+         ffn_bias=True,
+         proj_bias=True,
+         drop_path_rate=0.0,
+         drop_path_uniform=False,
+         init_values=1.0,
+         embed_layer=PatchEmbed,
+         act_layer=nn.GELU,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         ffn_layer="mlp",
+         block_chunks=0,
+         output_idx=[6, 12, 18, 24],
+         num_register_tokens=0,
+         interpolate_antialias=False,
+         use_norm=True,
+         frozen_stages=0,
+     ):
+         """
+         Args:
+             img_size (int, tuple): input image size
+             patch_size (int, tuple): patch size
+             in_chans (int): number of input channels
+             embed_dim (int): embedding dimension
+             depth (int): depth of transformer
+             num_heads (int): number of attention heads
+             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+             qkv_bias (bool): enable bias for qkv if True
+             proj_bias (bool): enable bias for proj in attn if True
+             ffn_bias (bool): enable bias for ffn if True
+             drop_path_rate (float): stochastic depth rate
+             drop_path_uniform (bool): apply uniform drop rate across blocks
+             weight_init (str): weight init scheme
+             init_values (float): layer-scale init values
+             embed_layer (nn.Module): patch embedding layer
+             act_layer (nn.Module): MLP activation layer
+             block_fn (nn.Module): transformer block class
+             ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+             block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+         """
+         super().__init__()
+         self.frozen_stages = frozen_stages
+         self.patch_size = patch_size
+         self.output_idx = output_idx
+         self.num_register_tokens = num_register_tokens
+         self.interpolate_antialias = interpolate_antialias
+
+         self.patch_embed = PatchEmbed(
+             img_size=img_size,
+             patch_size=patch_size,
+             in_chans=in_chans,
+             embed_dim=embed_dim,
+         )
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(
+             torch.zeros(1, num_patches + 1, embed_dim)
+         )
+         assert num_register_tokens >= 0
+         self.register_tokens = nn.Parameter(
+             torch.zeros(1, max(1, num_register_tokens), embed_dim)
+         )
+
+         if drop_path_uniform is True:
+             dpr = [drop_path_rate] * depth
+         else:
+             dpr = [
+                 x.item() for x in torch.linspace(0, drop_path_rate, depth)
+             ]
+
+         if ffn_layer == "mlp":
+             ffn_layer = Mlp
+         elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+             ffn_layer = SwiGLUFFNFused
+         elif ffn_layer == "identity":
+             def f():
+                 return nn.Identity()
+             ffn_layer = f
+         else:
+             raise NotImplementedError
+
+         blocks_list = [
+             block_fn(
+                 dim=embed_dim,
+                 num_heads=num_heads,
+                 mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias,
+                 proj_bias=proj_bias,
+                 ffn_bias=ffn_bias,
+                 drop_path=dpr[i],
+                 norm_layer=nn.LayerNorm,
+                 act_layer=act_layer,
+                 ffn_layer=ffn_layer,
+                 init_values=init_values,
+             )
+             for i in range(depth)
+         ]
+         self.chunked_blocks = False
+         self.blocks = nn.ModuleList(blocks_list)
+
+         self.norm = nn.LayerNorm(embed_dim)
+         self.use_norm = use_norm
+         self.head = nn.Identity()
+         self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+     def interpolate_pos_encoding(self, x, W, H):
+         previous_dtype = x.dtype
+         N = self.pos_embed.shape[1] - 1
+         pos_embed = self.pos_embed.float()
+         class_pos_embed = pos_embed[:, 0]
+         patch_pos_embed = pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = W // self.patch_size
+         h0 = H // self.patch_size
+
+         M = int(math.sqrt(N))
+         assert N == M * M
+         kwargs = {}
+         kwargs["size"] = (w0, h0)
+
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+             mode="bicubic",
+             antialias=self.interpolate_antialias,
+             **kwargs,
+         )
+         assert (w0, h0) == patch_pos_embed.shape[-2:]
+
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+             previous_dtype
+         )
+
+     def tokenize(self, x):
+         _, _, W, H = x.shape
+         with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext():
+             x = self.patch_embed(x)
+         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+         dino_pos_embed = self.interpolate_pos_encoding(x, W, H)
+         x = x + dino_pos_embed
+         return x
+
+     def forward_features(self, x):
+         shapes = [val // self.patch_size for val in x.shape[-2:]]
+         batch_size = x.shape[0]
+         features = []
+         x = self.tokenize(x)
+         for i, blk in enumerate(self.blocks):
+             with (
+                 torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext()
+             ):
+                 x = blk(x)
+                 features.append(x)
+         if self.use_norm:
+             with (
+                 torch.no_grad()
+                 if self.frozen_stages >= len(self.blocks)
+                 else contextlib.nullcontext()
+             ):
+                 features = [self.norm(out) for out in features]
+         features = [out[:, self.num_register_tokens + 1 :] for out in features]
+         features = [out.reshape(batch_size, *shapes, -1) for out in features]
+         return features
+
+     def forward(self, *args):
+         features = self.forward_features(*args)
+         return features
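Editor's note: a reduced-size sketch of how DINOViT returns one feature map per block (not part of the commit; import path assumed). The shipped defaults correspond to ViT-L (embed_dim=1024, depth=24, patch size 14); the tiny sizes below only keep the example light:

import torch
from da2.model.dinov2.dinovit import DINOViT  # assumed import path

vit = DINOViT(img_size=224, patch_size=14, embed_dim=64, depth=2,
              num_heads=2, output_idx=[1, 2]).eval()
img = torch.randn(1, 3, 224, 224)         # H and W must be multiples of patch_size
with torch.no_grad():
    feats = vit(img)                      # list with one entry per block
print(len(feats), feats[0].shape)         # 2 torch.Size([1, 16, 16, 64])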
da2/model/dinov2/drop_path.py ADDED
@@ -0,0 +1,37 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+ import torch.nn as nn
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (
+         x.ndim - 1
+     )  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
da2/model/dinov2/layer_scale.py ADDED
@@ -0,0 +1,28 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+ from typing import Union
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
da2/model/dinov2/mlp.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
da2/model/dinov2/patch_embed.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # References:
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ from typing import Callable, Optional, Tuple, Union
+
+ import torch.nn as nn
+ from torch import Tensor
+
+
+ def make_2tuple(x):
+     if isinstance(x, tuple):
+         assert len(x) == 2
+         return x
+
+     assert isinstance(x, int)
+     return (x, x)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+     Args:
+         img_size: Image size.
+         patch_size: Patch token size.
+         in_chans: Number of input image channels.
+         embed_dim: Number of linear projection output channels.
+         norm_layer: Normalization layer.
+     """
+
+     def __init__(
+         self,
+         img_size: Union[int, Tuple[int, int]] = 224,
+         patch_size: Union[int, Tuple[int, int]] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten_embedding: bool = True,
+     ) -> None:
+         super().__init__()
+
+         image_HW = make_2tuple(img_size)
+         patch_HW = make_2tuple(patch_size)
+         patch_grid_size = (
+             image_HW[0] // patch_HW[0],
+             image_HW[1] // patch_HW[1],
+         )
+
+         self.img_size = image_HW
+         self.patch_size = patch_HW
+         self.patches_resolution = patch_grid_size
+         self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.flatten_embedding = flatten_embedding
+
+         self.proj = nn.Conv2d(
+             in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+         )
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         patch_H, patch_W = self.patch_size
+
+         assert (
+             H % patch_H == 0
+         ), f"Input image height {H} is not a multiple of patch height {patch_H}"
+         assert (
+             W % patch_W == 0
+         ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+         x = self.proj(x)  # B C H W
+         H, W = x.size(2), x.size(3)
+         x = x.flatten(2).transpose(1, 2)  # B HW C
+         x = self.norm(x)
+         if not self.flatten_embedding:
+             x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+         return x
+
+     def flops(self) -> float:
+         Ho, Wo = self.patches_resolution
+         flops = (
+             Ho
+             * Wo
+             * self.embed_dim
+             * self.in_chans
+             * (self.patch_size[0] * self.patch_size[1])
+         )
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
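Editor's note: a one-line shape check for PatchEmbed (not part of the commit; import path assumed): a 224x224 RGB image with 16x16 patches becomes 196 tokens of width embed_dim:

import torch
from da2.model.dinov2.patch_embed import PatchEmbed  # assumed import path

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)                       # torch.Size([1, 196, 768])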
da2/model/dinov2/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Callable, Optional
+
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+         self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x12 = self.w12(x)
+         x1, x2 = x12.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         return self.w3(hidden)
+
+
+ try:
+     from xformers.ops import SwiGLU
+
+     XFORMERS_AVAILABLE = True
+ except ImportError:
+     SwiGLU = SwiGLUFFN
+     XFORMERS_AVAILABLE = False
+
+
+ class SwiGLUFFNFused(SwiGLU):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+         super().__init__(
+             in_features=in_features,
+             hidden_features=hidden_features,
+             out_features=out_features,
+             bias=bias,
+         )
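Editor's note: a quick sketch of the gating above (not part of the commit; import path assumed). w12 produces two hidden-sized halves, the first is passed through SiLU and multiplies the second; SwiGLUFFNFused additionally shrinks hidden_features by 2/3 and rounds it up to a multiple of 8:

import torch
from da2.model.dinov2.swiglu_ffn import SwiGLUFFN  # assumed import path

ffn = SwiGLUFFN(in_features=64, hidden_features=128)
out = ffn(torch.randn(2, 10, 64))         # silu(x1) * x2, then w3
print(out.shape)                          # torch.Size([2, 10, 64])
# SwiGLUFFNFused would use (int(128 * 2 / 3) + 7) // 8 * 8 = 88 hidden features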
da2/model/sphere.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+
+
+ def get_uv_gird(h, w, device):
+     pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device)
+     pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device)
+     stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]
+     grid = torch.stack(stacks, dim=0).float()
+     grid = grid.to(device).unsqueeze(0)
+     return grid
+
+
+ class Sphere():
+     def __init__(self, config, device):
+         self.config = config
+         self.device = device
+
+     def get_directions(self, shape):
+         h, w = shape
+         uv = get_uv_gird(h, w, device=self.device)
+         u, v = uv.unbind(dim=1)
+         width, height = self.config['width'], self.config['height']
+         hfov, vfov = self.config['hfov'], self.config['vfov']
+         longitude = (u - width / 2) / width * hfov
+         latitude = (v - height / 2) / height * vfov
+         x = torch.cos(latitude) * torch.sin(longitude)
+         z = torch.cos(latitude) * torch.cos(longitude)
+         y = torch.sin(latitude)
+         sphere_directions = torch.stack([x, y, z], dim=1)
+         return sphere_directions
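Editor's note: a hedged sketch of Sphere.get_directions (not part of the commit). The config keys match the ones read above; the field-of-view values below are assumptions and must be in radians for the trigonometry to hold. The result is a (1, 3, h, w) field of unit ray directions, with the image centre mapping to roughly (0, 0, 1):

import math
import torch
from da2.model.sphere import Sphere  # assumed import path

config = {'width': 640, 'height': 480,
          'hfov': math.radians(90.0), 'vfov': math.radians(70.0)}  # assumed FoVs
sphere = Sphere(config=config, device=torch.device('cpu'))
dirs = sphere.get_directions(shape=(480, 640))
print(dirs.shape)                         # torch.Size([1, 3, 480, 640])
print(dirs[0, :, 240, 320])               # ~[0, 0, 1] near the image centre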
da2/model/spherevit.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from copy import deepcopy
+ from math import (
+     ceil,
+     sqrt
+ )
+ from huggingface_hub import PyTorchModelHubMixin
+ import torchvision.transforms.v2.functional as TF
+ from .dinov2 import DINOViT
+ from .vit_w_esphere import ViT_w_Esphere
+ from .sphere import Sphere
+
+
+ IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
+
+ class SphereViT(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.dino = DINOViT()
+         self.vit_w_esphere = ViT_w_Esphere(config['spherevit']['vit_w_esphere'])
+         feature_slices = self.dino.output_idx
+         self.feature_slices = list(
+             zip([0, *feature_slices[:-1]], feature_slices)
+         )
+         self.device = None
+
+     def to(self, *args):
+         self.device = args[0]
+         return super().to(*args)
+
+     def forward(self, images):
+         B, _, H, W = images.shape
+         current_pixels = H * W
+         target_pixels = min(self.config['inference']['max_pixels'],
+             max(self.config['inference']['min_pixels'], current_pixels))
+         factor = sqrt(target_pixels / current_pixels)
+         sphere_config = deepcopy(self.config['spherevit']['sphere'])
+         sphere_config['width'] *= factor
+         sphere_config['height'] *= factor
+         sphere = Sphere(config=sphere_config, device=self.device)
+         H_new = int(H * factor)
+         W_new = int(W * factor)
+         DINO_patch_size = 14  # please see the line 51 of `src/da2/model/dinov2/dinovit.py` (I know it's a little ugly to hardcode it here T_T)
+         H_new = ceil(H_new / DINO_patch_size) * DINO_patch_size
+         W_new = ceil(W_new / DINO_patch_size) * DINO_patch_size
+         images = F.interpolate(images, size=(H_new, W_new), mode='bilinear', align_corners=False)
+         images = TF.normalize(
+             images.float(),
+             mean=IMAGENET_DATASET_MEAN,
+             std=IMAGENET_DATASET_STD,
+         )
+
+         sphere_dirs = sphere.get_directions(shape=(H_new, W_new))
+         sphere_dirs = sphere_dirs.to(self.device)
+         sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1)
+
+         features = self.dino(images)
+         features = [
+             features[i:j][-1].contiguous()
+             for i, j in self.feature_slices
+         ]
+         distance = self.vit_w_esphere(images, features, sphere_dirs)
+         distance = F.interpolate(distance, size=(H, W), mode='bilinear', align_corners=False)
+         distance = distance.squeeze(dim=1)  # (b, 1, h, w) -> (b, h, w)
+         return distance
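Editor's note: the resizing logic in SphereViT.forward, restated as a standalone sketch (not part of the commit; the pixel budgets below are placeholder assumptions, the real ones come from config['inference']). The pixel count is clamped into [min_pixels, max_pixels] and each side is rounded up to a multiple of the DINOv2 patch size of 14:

from math import ceil, sqrt

def resized_hw(H, W, min_pixels=500_000, max_pixels=1_500_000, patch=14):
    # clamp the pixel budget, then derive an isotropic scale factor
    target = min(max_pixels, max(min_pixels, H * W))
    factor = sqrt(target / (H * W))
    H_new, W_new = int(H * factor), int(W * factor)
    # round each side up to a multiple of the ViT patch size
    return ceil(H_new / patch) * patch, ceil(W_new / patch) * patch

print(resized_hw(1080, 1920))             # (924, 1638) for the assumed budgets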
da2/model/vit_w_esphere.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from .base import (
6
+ fourier_dimension_expansion,
7
+ flatten,
8
+ DimensionAligner,
9
+ AttentionSeq,
10
+ ResidualUpsampler
11
+ )
12
+
13
+
14
+ class _ViT_w_Esphere(nn.Module):
15
+ def __init__(
16
+ self,
17
+ hidden_dim: int,
18
+ num_heads: int = 8,
19
+ expansion: int = 4,
20
+ num_layers_head: int | list[int] = 4,
21
+ dropout: float = 0.0,
22
+ kernel_size: int = 7,
23
+ layer_scale: float = 1.0,
24
+ out_dim: int = 1,
25
+ num_prompt_blocks: int = 1,
26
+ use_norm: bool = False,
27
+ **kwargs,
28
+ ) -> None:
29
+ super().__init__()
30
+ self.out_dim = out_dim
31
+ self.hidden_dim = hidden_dim
32
+ self.up_sampler = nn.ModuleList([])
33
+ self.pred_head = nn.ModuleList([])
34
+ self.process_features = nn.ModuleList([])
35
+ self.prompt_camera = nn.ModuleList([])
36
+ mult = 2
37
+ self.to_latents = nn.Linear(hidden_dim, hidden_dim)
38
+
39
+ for _ in range(4):
40
+ self.prompt_camera.append(
41
+ AttentionSeq(
42
+ num_blocks=num_prompt_blocks,
43
+ dim=hidden_dim,
44
+ num_heads=num_heads,
45
+ expansion=expansion,
46
+ dropout=dropout,
47
+ layer_scale=-1.0,
48
+ context_dim=hidden_dim,
49
+ )
50
+ )
51
+ 
+         for i, depth in enumerate(num_layers_head):
+             current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
+             next_dim = mult * hidden_dim // int(2 ** (i + 1))
+             output_dim = max(next_dim, out_dim)
+             self.process_features.append(
+                 nn.ConvTranspose2d(
+                     hidden_dim,
+                     current_dim,
+                     kernel_size=max(1, 2 * i),
+                     stride=max(1, 2 * i),
+                     padding=0,
+                 )
+             )
+             self.up_sampler.append(
+                 ResidualUpsampler(
+                     current_dim,
+                     output_dim=output_dim,
+                     expansion=expansion,
+                     layer_scale=layer_scale,
+                     kernel_size=kernel_size,
+                     num_layers=depth,
+                     use_norm=use_norm,
+                 )
+             )
+             pred_head = (
+                 nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
+                 if i == len(num_layers_head) - 1
+                 else nn.Identity()
+             )
+             self.pred_head.append(pred_head)
+ 
+         self.to_depth_lr = nn.Conv2d(
+             output_dim,
+             output_dim // 2,
+             kernel_size=3,
+             padding=1,
+             padding_mode='reflect',
+         )
+         self.to_confidence_lr = nn.Conv2d(
+             output_dim,
+             output_dim // 2,
+             kernel_size=3,
+             padding=1,
+             padding_mode='reflect',
+         )
+         self.to_depth_hr = nn.Sequential(
+             nn.Conv2d(
+                 output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
+             ),
+             nn.LeakyReLU(),
+             nn.Conv2d(32, 1, kernel_size=1),
+         )
+         self.to_confidence_hr = nn.Sequential(
+             nn.Conv2d(
+                 output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
+             ),
+             nn.LeakyReLU(),
+             nn.Conv2d(32, 1, kernel_size=1),
+         )
+ 
+     def set_original_shapes(self, shapes: tuple[int, int]):
+         self.original_shapes = shapes
+ 
+     def set_shapes(self, shapes: tuple[int, int]):
+         self.shapes = shapes
+ 
+     def embed_sphere_dirs(self, sphere_dirs):
+         sphere_embedding = flatten(
+             sphere_dirs, old=self.original_shapes, new=self.shapes
+         )
+         # index 0 -> Y
+         # index 1 -> Z
+         # index 2 -> X
+         r1, r2, r3 = sphere_embedding[..., 0], sphere_embedding[..., 1], sphere_embedding[..., 2]
+         polar = torch.asin(r2)
+         r3_clipped = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
+         azimuth = torch.atan2(r1, r3_clipped)
+         # [polar, azimuth] is the angle field
+         sphere_embedding = torch.stack([polar, azimuth], dim=-1)
+         # expand the dimension of the angle field to image feature dimensions, via sine-cosine basis embedding
+         sphere_embedding = fourier_dimension_expansion(
+             sphere_embedding,
+             dim=self.hidden_dim,
+             max_freq=max(self.shapes) // 2,
+             use_cos=False,
+         )
+         return sphere_embedding
+ 
+     def condition(self, feat, sphere_embeddings):
+         conditioned_features = [
+             prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings)
+             for prompter, feature in zip(self.prompt_camera, feat)
+         ]
+         return conditioned_features
+ 
+     def process(self, features_list, sphere_embeddings):
+         conditioned_features = self.condition(features_list, sphere_embeddings)
+         init_latents = self.to_latents(conditioned_features[0])
+         init_latents = rearrange(
+             init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
+         ).contiguous()
+         conditioned_features = [
+             rearrange(
+                 x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
+             ).contiguous()
+             for x in conditioned_features
+         ]
+         latents = init_latents
+ 
+         out_features = []
+         # Pyramid-like multi-layer convolutional feature extraction
+         for i, up in enumerate(self.up_sampler):
+             latents = latents + self.process_features[i](conditioned_features[i + 1])
+             latents = up(latents)
+             out_features.append(latents)
+         return out_features
+ 
+     def prediction_head(self, out_features):
+         depths = []
+         h_out, w_out = out_features[-1].shape[-2:]
+         for i, (layer, features) in enumerate(zip(self.pred_head, out_features)):
+             out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+             if i < len(self.pred_head) - 1:
+                 continue
+             depths.append(out_depth_features)
+         out_depth_features = F.interpolate(
+             out_depth_features, size=(h_out, w_out), mode='bilinear', align_corners=True
+         )
+         distance = self.to_depth_lr(out_depth_features)
+         distance = F.interpolate(
+             distance, size=self.original_shapes, mode='bilinear', align_corners=True
+         )
+         distance = self.to_depth_hr(distance)
+         return distance
+ 
+     def forward(
+         self,
+         features: list[torch.Tensor],
+         sphere_dirs: torch.Tensor
+     ) -> torch.Tensor:
+         sphere_embeddings = self.embed_sphere_dirs(sphere_dirs)
+         features = self.process(features, sphere_embeddings)
+         distance = self.prediction_head(features)
+         return distance
+ 
+ 
+ class ViT_w_Esphere(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.dim_aligner = DimensionAligner(
+             input_dims=config['input_dims'],
+             hidden_dim=config['hidden_dim'],
+         )
+         self._vit_w_esphere = _ViT_w_Esphere(**config)
+ 
+     def forward(self, images, features, sphere_dirs) -> torch.Tensor:
+         _, _, H, W = images.shape
+         sphere_dirs = sphere_dirs
+         common_shape = features[0].shape[1:3]
+         features = self.dim_aligner(features)
+         sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c')
+ 
+         self._vit_w_esphere.set_shapes(common_shape)
+         self._vit_w_esphere.set_original_shapes((H, W))
+         logdistance = self._vit_w_esphere(
+             features=features,
+             sphere_dirs=sphere_dirs,
+         )
+ 
+         distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0)
+         distance = distance / torch.quantile(distance, 0.98)
+         return distance
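
Note: the embed_sphere_dirs path above turns per-pixel unit ray directions into a (polar, azimuth) angle field and then lifts that field to the feature width with a sine-cosine basis. A minimal standalone sketch of the same idea follows (illustrative only; fourier_expand below is a hypothetical stand-in for the repo's fourier_dimension_expansion, whose definition is not part of this diff):

import math
import torch

def angle_field(dirs: torch.Tensor) -> torch.Tensor:
    # dirs: (..., 3) unit vectors ordered (Y, Z, X), matching embed_sphere_dirs above
    r1, r2, r3 = dirs[..., 0], dirs[..., 1], dirs[..., 2]
    polar = torch.asin(r2.clamp(-1.0, 1.0))
    # keep the sign of r3 but bound it away from zero, mirroring r3_clipped above
    r3_safe = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
    azimuth = torch.atan2(r1, r3_safe)
    return torch.stack([polar, azimuth], dim=-1)              # (..., 2)

def fourier_expand(x: torch.Tensor, dim: int, max_freq: int) -> torch.Tensor:
    # Hypothetical stand-in: sine-only basis at log-spaced frequencies,
    # zero-padded up to `dim` channels.
    n_freqs = dim // (2 * x.shape[-1])
    freqs = torch.logspace(0.0, math.log10(float(max_freq)), n_freqs)
    emb = torch.sin(x[..., None] * freqs).flatten(-2)         # (..., 2 * n_freqs)
    return torch.nn.functional.pad(emb, (0, dim - emb.shape[-1]))

dirs = torch.nn.functional.normalize(torch.randn(2, 16 * 16, 3), dim=-1)
emb = fourier_expand(angle_field(dirs), dim=1024, max_freq=8)
print(emb.shape)  # torch.Size([2, 256, 1024])

The sign-preserving clamp keeps atan2 well defined when a ray lies in the polar plane, which is the same reason the code above builds r3_clipped before calling torch.atan2.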
da2/utils/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .base import (
+     prepare_to_run
+ )
+ from .model import (
+     load_model
+ )
+ 
+ __all__ = [
+     'prepare_to_run',
+     'load_model'
+ ]
da2/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (287 Bytes). View file
 
da2/utils/__pycache__/base.cpython-312.pyc ADDED
Binary file (3.22 kB). View file
 
da2/utils/__pycache__/d2pc.cpython-312.pyc ADDED
Binary file (6.21 kB). View file
 
da2/utils/__pycache__/io.cpython-312.pyc ADDED
Binary file (3.62 kB). View file
 
da2/utils/__pycache__/model.cpython-312.pyc ADDED
Binary file (1.23 kB). View file
 
da2/utils/__pycache__/vis.cpython-312.pyc ADDED
Binary file (2.42 kB). View file
 
da2/utils/base.py ADDED
@@ -0,0 +1,56 @@
+ import json
+ import argparse
+ import os
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import (
+     InitProcessGroupKwargs,
+     ProjectConfiguration,
+     set_seed
+ )
+ import logging
+ from datetime import (
+     timedelta,
+     datetime
+ )
+ 
+ 
+ def load_config(config_path):
+     with open(config_path, 'r') as f:
+         config = json.load(f)
+     return config
+ 
+ def arg_parser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config_path', type=str, required=True)
+     args = parser.parse_args()
+     return args
+ 
+ def prepare_to_run():
+     args = arg_parser()
+     logging.basicConfig(
+         format='%(asctime)s --> %(message)s',
+         datefmt='%m/%d %H:%M:%S',
+         level=logging.INFO,
+     )
+     config = load_config(args.config_path)
+     kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
+     version = os.path.basename(args.config_path)[:-5]
+     output_dir = f'output/{version}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
+     if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
+     accu_steps = config['accelerator']['accumulation_nsteps']
+     accelerator = Accelerator(
+         gradient_accumulation_steps=accu_steps,
+         mixed_precision=config['accelerator']['mixed_precision'],
+         log_with=config['accelerator']['report_to'],
+         project_config=ProjectConfiguration(project_dir=output_dir),
+         kwargs_handlers=[kwargs]
+     )
+     logger = get_logger(__name__, log_level='INFO')
+     config['env']['logger'] = logger
+     set_seed(config['env']['seed'])
+     if config['env']['verbose']:
+         logger.info(f'Version: {version} (from {args.config_path})')
+         logger.info(f'Output dir: {output_dir}')
+         logger.info(f'Using {accelerator.num_processes} GPU' + ('s' if accelerator.num_processes > 1 else ''))
+     return config, accelerator, output_dir
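
Note: prepare_to_run only reads the 'accelerator' and 'env' sections of the JSON config (other utils in this commit additionally read 'inference' and 'spherevit'). A minimal placeholder config covering just the keys used above could be generated like this; all values are illustrative, not defaults shipped with the repo:

import json

config = {
    'accelerator': {
        'timeout': 3600,               # seconds, passed to InitProcessGroupKwargs
        'accumulation_nsteps': 1,      # gradient accumulation steps
        'mixed_precision': 'no',       # or 'fp16' / 'bf16'
        'report_to': 'tensorboard',
    },
    'env': {
        'seed': 42,
        'verbose': True,
    },
}

with open('example.json', 'w') as f:
    json.dump(config, f, indent=4)
# The script calling prepare_to_run() would then be launched with --config_path example.json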
da2/utils/d2pc.py ADDED
@@ -0,0 +1,76 @@
+ import os
+ import numpy as np
+ import utils3d
+ from plyfile import PlyData, PlyElement
+ from PIL import Image
+ 
+ 
+ def sphere_uv2dirs(uv: np.ndarray):
+     theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
+     directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
+     return directions
+ 
+ def save_3d_points(points: np.array, colors: np.array, mask: np.array, save_path: str):
+     points = points.reshape(-1, 3)
+     colors = colors.reshape(-1, 3)
+     mask = mask.reshape(-1, 1)
+ 
+     vertex_data = np.empty(mask.sum(), dtype=[
+         ('x', 'f4'),
+         ('y', 'f4'),
+         ('z', 'f4'),
+         ('red', 'u1'),
+         ('green', 'u1'),
+         ('blue', 'u1'),
+     ])
+     vertex_data['x'] = [a for i, a in enumerate(points[:, 0]) if mask[i]]
+     vertex_data['y'] = [a for i, a in enumerate(points[:, 1]) if mask[i]]
+     vertex_data['z'] = [a for i, a in enumerate(points[:, 2]) if mask[i]]
+     vertex_data['red'] = [a for i, a in enumerate(colors[:, 0]) if mask[i]]
+     vertex_data['green'] = [a for i, a in enumerate(colors[:, 1]) if mask[i]]
+     vertex_data['blue'] = [a for i, a in enumerate(colors[:, 2]) if mask[i]]
+ 
+     vertex_element = PlyElement.describe(vertex_data, 'vertex', comments=['vertices with color'])
+     save_dir = os.path.dirname(save_path)
+     if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True)
+     PlyData([vertex_element], text=True).write(save_path)
+ 
+ def colorize_normal(normal: np.ndarray, normal_mask: np.ndarray):
+     normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8)
+     normal_mask = np.repeat(np.expand_dims(normal_mask, axis=-1), 3, axis=-1)
+     normal_mask = normal_mask.astype(np.uint8)
+     normal_rgb = normal_rgb * normal_mask
+     return normal_rgb
+ 
+ def normal_normalize(normal: np.ndarray):
+     normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
+     normal_norm[normal_norm < 1e-6] = 1e-6
+     return normal / normal_norm
+ 
+ def distance2pointcloud(
+     distance: np.ndarray,
+     image: np.ndarray,
+     mask: np.ndarray,
+     save_path: str = None,
+     return_normal: bool = False,
+     save_distance: bool = False
+ ):
+     if distance.ndim >= 3: distance = distance.squeeze()
+     if save_distance:
+         save_path_dis = save_path.replace('3dpc', 'depth').replace('.ply', '.npy')
+         save_dir_dis = os.path.dirname(save_path_dis)
+         if not os.path.exists(save_dir_dis): os.makedirs(save_dir_dis, exist_ok=True)
+         np.save(save_path_dis, distance)
+     height, width = distance.shape[:2]
+     points = distance[:, :, None] * sphere_uv2dirs(utils3d.numpy.image_uv(width=width, height=height))
+     save_3d_points(points, image, mask, save_path)
+     if return_normal:
+         normal, normal_mask = utils3d.numpy.points_to_normals(points, mask)
+         normal = normal * np.array([-1, -1, 1])
+         normal = normal_normalize(normal)
+         normal_1 = normal[..., 0]
+         normal_2 = normal[..., 1]
+         normal_3 = normal[..., 2]
+         normal = np.stack([normal_1, normal_3, normal_2], axis=-1)
+         normal_img = colorize_normal(normal, normal_mask)
+         return Image.fromarray(normal_img)
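
Note: sphere_uv2dirs maps equirectangular UV coordinates to unit ray directions, which distance2pointcloud then scales by the per-pixel distance. A quick standalone check (sketch only) that the rays come out unit length, reusing the same utils3d.numpy.image_uv call as above:

import numpy as np
import utils3d

def sphere_uv2dirs(uv: np.ndarray) -> np.ndarray:
    # same equirectangular mapping as above: u -> azimuth theta, v -> polar angle phi
    theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
    return np.stack(
        [np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)],
        axis=-1,
    )

uv = utils3d.numpy.image_uv(width=8, height=4)            # (4, 8, 2) UV grid in [0, 1]
dirs = sphere_uv2dirs(uv)
print(np.allclose(np.linalg.norm(dirs, axis=-1), 1.0))    # True: every ray is unit length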
da2/utils/io.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import cv2
+ import torch
+ import numpy as np
+ from glob import glob
+ from PIL import Image
+ 
+ 
+ def torch_transform(image):
+     image = image / 255.0
+     image = np.transpose(image, (2, 0, 1))
+     return image
+ 
+ def read_cv2_image(image_path):
+     image = cv2.imread(image_path)
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return image
+ 
+ def read_mask(mask_path, shape):
+     if not os.path.exists(mask_path):
+         return np.ones(shape[1:]) > 0
+     mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+     mask = mask > 0
+     return mask
+ 
+ def tensorize(array, model_dtype, device):
+     array = torch.from_numpy(array).to(device).to(model_dtype).unsqueeze(dim=0)
+     return array
+ 
+ def load_infer_data(config, device):
+     image_dir = config['inference']['images']
+     mask_dir = config['inference']['masks']
+ 
+     image_paths = glob(os.path.join(image_dir, '*.png'))
+     image_paths = sorted(image_paths)
+     filenames = [os.path.basename(image_path)[:-4] for image_path in image_paths]
+     cv2_images = [read_cv2_image(image_path)
+                   for image_path in image_paths]
+     PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images]
+     images = [torch_transform(cv2_image) for cv2_image in cv2_images]
+ 
+     mask_paths = [image_path.replace(image_dir, mask_dir)
+                   for image_path in image_paths]
+     masks = [read_mask(mask_path, images[i].shape)
+              for (i, mask_path) in enumerate(mask_paths)]
+ 
+     model_dtype = config['spherevit']['dtype']
+     images = [tensorize(image, model_dtype, device) for image in images]
+ 
+     infer_data = {
+         'images': {
+             'PIL': PIL_images,
+             'cv2': cv2_images,
+             'torch': images
+         },
+         'masks': masks,
+         'filenames': filenames,
+         'size': len(images)
+     }
+     if config['env']['verbose']:
+         s = 's' if len(images) > 1 else ''
+         config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}')
+     return infer_data
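
Note: load_infer_data globs *.png files under config['inference']['images'] and derives each mask path by swapping that directory for config['inference']['masks']; a missing mask falls back to an all-valid mask. A hypothetical directory layout and matching config fragment (paths are illustrative only):

# data/images/room.png   <- read with cv2 and converted to RGB
# data/masks/room.png    <- optional; absent masks default to all-valid
inference_cfg = {
    'inference': {
        'images': 'data/images',   # only *.png files are picked up
        'masks': 'data/masks',     # must mirror the image filenames
    }
}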
da2/utils/model.py ADDED
@@ -0,0 +1,15 @@
+ import torch
+ from ..model.spherevit import SphereViT
+ 
+ 
+ def load_model(config, accelerator):
+     model = SphereViT.from_pretrained('haodongli/DA-2', config=config)
+     model = model.to(accelerator.device)
+     torch.cuda.empty_cache()
+     model = accelerator.prepare(model)
+     if accelerator.num_processes > 1:
+         model = model.module
+     if config['env']['verbose']:
+         config['env']['logger'].info(f'Model\'s dtype: {next(model.parameters()).dtype}.')
+     config['spherevit']['dtype'] = next(model.parameters()).dtype
+     return model
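
Note: combined with the helpers above, a single-process inference loop could be sketched as below. The model call is an assumption: SphereViT's forward signature lives in da2/model/spherevit.py and is not reproduced here, so model(image) is a placeholder rather than the repo's actual entry point. load_model also expects config['spherevit'] to exist, and load_infer_data reads config['inference'] and config['spherevit']['dtype'].

import torch
from da2.utils.base import prepare_to_run
from da2.utils.model import load_model
from da2.utils.io import load_infer_data
from da2.utils.vis import colorize_distance

config, accelerator, output_dir = prepare_to_run()
model = load_model(config, accelerator)
data = load_infer_data(config, accelerator.device)

with torch.no_grad():
    for image, mask, name in zip(data['images']['torch'], data['masks'], data['filenames']):
        distance = model(image)   # placeholder call; see spherevit.py for the real signature
        distance = distance.squeeze().float().cpu().numpy()
        colorize_distance(distance, mask).save(f'{output_dir}/{name}_distance.png')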
da2/utils/vis.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ from PIL import Image
+ import numpy as np
+ import matplotlib
+ import cv2
+ 
+ 
+ def concatenate_images(*image_lists):
+     max_width = 0
+     total_height = 0
+     row_widths = []
+     row_heights = []
+ 
+     for i, image_list in enumerate(image_lists):
+         width = sum(img.width for img in image_list)
+         max_width = max(max_width, width)
+         row_widths.append(width)
+         # Assuming all images in the list have the same height
+         height = image_list[0].height
+         total_height += height
+         row_heights.append(height)
+ 
+     new_image = Image.new('RGB', (max_width, total_height))
+     y_offset = 0
+     for i, image_list in enumerate(image_lists):
+         x_offset = 0
+         for img in image_list:
+             new_image.paste(img, (x_offset, y_offset))
+             x_offset += img.width
+         y_offset += row_heights[i]
+     return new_image
+ 
+ def colorize_distance(distance, mask, cmap='Spectral'):
+     if distance.ndim >= 3: distance = distance.squeeze()
+     cm = matplotlib.colormaps[cmap]
+     valid_distance = distance[mask]
+     max_distance = np.quantile(valid_distance, 0.98)
+     min_distance = np.quantile(valid_distance, 0.02)
+     distance[~mask] = max_distance
+     distance = ((distance - min_distance) / (max_distance - min_distance))
+     distance = np.clip(distance, 0, 1)
+     img_colored_np = cm(distance, bytes=False)[:, :, 0:3]
+     distance_colored = (img_colored_np * 255).astype(np.uint8)
+     return Image.fromarray(distance_colored)
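
Note: colorize_distance normalizes the distance map between its 2nd and 98th percentiles before applying the Spectral colormap (it also writes into the array it is given, hence the copy below), and concatenate_images tiles one row per list of PIL images. A small synthetic sanity check (sketch only):

import numpy as np
from PIL import Image
from da2.utils.vis import colorize_distance, concatenate_images

distance = np.linspace(1.0, 10.0, 64 * 128).reshape(64, 128)
mask = np.ones_like(distance, dtype=bool)

vis = colorize_distance(distance.copy(), mask)        # PIL.Image in the Spectral colormap
dummy_rgb = Image.fromarray(np.zeros((64, 128, 3), dtype=np.uint8))
grid = concatenate_images([dummy_rgb, vis])           # one row: dummy input | colorized distance
grid.save('sanity_check.png')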
requirements.txt CHANGED
@@ -1 +1,19 @@
- -e /tmp/src/
+ torch==2.5.0
+ torchvision==0.20.0
+ torchaudio==2.5.0
+ xformers==0.0.28.post2
+ diffusers==0.32.0
+ tensorboard==2.18.0
+ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900
+ opencv-python==4.12.0.88
+ gradio==5.49.0
+ gradio-client==1.13.3
+ gradio-imageslider==0.0.20
+ accelerate==1.1.1
+ omegaconf==2.3.0
+ tabulate==0.9.0
+ einops==0.8.0
+ timm==1.0.15
+ trimesh==4.5.2
+ transformers==4.46.3
+ matplotlib==3.9.2