update
- da2/__init__.py +25 -0
- da2/__pycache__/__init__.cpython-312.pyc +0 -0
- da2/model/__init__.py +11 -0
- da2/model/__pycache__/__init__.cpython-312.pyc +0 -0
- da2/model/__pycache__/base.cpython-312.pyc +0 -0
- da2/model/__pycache__/sphere.cpython-312.pyc +0 -0
- da2/model/__pycache__/spherevit.cpython-312.pyc +0 -0
- da2/model/__pycache__/vit_w_esphere.cpython-312.pyc +0 -0
- da2/model/base.py +393 -0
- da2/model/dinov2/__init__.py +13 -0
- da2/model/dinov2/__pycache__/__init__.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/attention.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/block.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/mlp.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc +0 -0
- da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc +0 -0
- da2/model/dinov2/attention.py +79 -0
- da2/model/dinov2/block.py +280 -0
- da2/model/dinov2/dino_head.py +68 -0
- da2/model/dinov2/dinovit.py +223 -0
- da2/model/dinov2/drop_path.py +37 -0
- da2/model/dinov2/layer_scale.py +28 -0
- da2/model/dinov2/mlp.py +41 -0
- da2/model/dinov2/patch_embed.py +101 -0
- da2/model/dinov2/swiglu_ffn.py +63 -0
- da2/model/sphere.py +30 -0
- da2/model/spherevit.py +69 -0
- da2/model/vit_w_esphere.py +224 -0
- da2/utils/__init__.py +11 -0
- da2/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- da2/utils/__pycache__/base.cpython-312.pyc +0 -0
- da2/utils/__pycache__/d2pc.cpython-312.pyc +0 -0
- da2/utils/__pycache__/io.cpython-312.pyc +0 -0
- da2/utils/__pycache__/model.cpython-312.pyc +0 -0
- da2/utils/__pycache__/vis.cpython-312.pyc +0 -0
- da2/utils/base.py +56 -0
- da2/utils/d2pc.py +76 -0
- da2/utils/io.py +63 -0
- da2/utils/model.py +15 -0
- da2/utils/vis.py +44 -0
- requirements.txt +19 -1
da2/__init__.py
ADDED
@@ -0,0 +1,25 @@
+from .utils.base import (
+    prepare_to_run
+)
+from .utils.model import (
+    load_model
+)
+from .utils.io import (
+    load_infer_data
+)
+from .utils.vis import (
+    colorize_distance,
+    concatenate_images
+)
+from .utils.d2pc import (
+    distance2pointcloud
+)
+
+__all__ = [
+    'prepare_to_run',
+    'load_model',
+    'load_infer_data',
+    'colorize_distance',
+    'concatenate_images',
+    'distance2pointcloud'
+]
da2/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (498 Bytes)
da2/model/__init__.py
ADDED
@@ -0,0 +1,11 @@
+from .spherevit import (
+    SphereViT
+)
+from .vit_w_esphere import (
+    ViT_w_Esphere
+)
+
+__all__ = [
+    'SphereViT',
+    'ViT_w_Esphere',
+]
da2/model/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (298 Bytes)

da2/model/__pycache__/base.cpython-312.pyc
ADDED
Binary file (18.4 kB)

da2/model/__pycache__/sphere.cpython-312.pyc
ADDED
Binary file (2.39 kB)

da2/model/__pycache__/spherevit.cpython-312.pyc
ADDED
Binary file (3.96 kB)

da2/model/__pycache__/vit_w_esphere.cpython-312.pyc
ADDED
Binary file (10.5 kB)
da2/model/base.py
ADDED
@@ -0,0 +1,393 @@
+import torch
+import torch.nn as nn
+from math import log2, pi
+from typing import Tuple
+import torch.nn.functional as F
+from einops import rearrange
+from functools import partial
+
+
+def fourier_dimension_expansion(
+    x: torch.Tensor,
+    dim: int = 512,
+    max_freq: int = 64,
+    use_cos: bool = True,
+    use_log: bool = True,
+):
+    device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
+    # input_dim: 2
+    num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
+    # num_bands = 512 // 2 = 256
+    if use_log:
+        scales = 2.0 ** torch.linspace(
+            0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
+        )
+    else:
+        scales = torch.linspace(
+            1.0, max_freq / 2, num_bands, device=device, dtype=dtype
+        )
+    x = x.unsqueeze(-1)
+    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+    x = x * scales * pi
+    x = torch.cat(
+        (
+            [x.sin(), x.cos()]
+            if use_cos
+            else [
+                x.sin(),
+            ]
+        ),
+        dim=-1,
+    )
+    x = x.flatten(-2)
+    return x
+
+def flatten(
+    flat_tensor: torch.Tensor,
+    old: Tuple[int, int],
+    new: Tuple[int, int],
+) -> torch.Tensor:
+    if old[0] == new[0] and old[1] == new[1]:
+        return flat_tensor
+    tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
+        0, 3, 1, 2
+    )  # b c h w
+    tensor_interp = F.interpolate(
+        tensor,
+        size=(new[0], new[1]),
+        mode='nearest',
+    )
+    flat_tensor_interp = tensor_interp.view(
+        flat_tensor.shape[0], -1, new[0] * new[1]
+    ).permute(
+        0, 2, 1
+    )  # b (h w) c
+    return flat_tensor_interp.contiguous()
+
+
+class DimensionAligner(nn.Module):
+    def __init__(self, input_dims: list[int], hidden_dim: int):
+        super().__init__()
+        self.aligners = nn.ModuleList([])
+        self.num_chunks = len(input_dims)
+        self.checkpoint = True
+        for input_dim in input_dims:
+            self.aligners.append(nn.Linear(input_dim, hidden_dim))
+
+    def forward(self, xs: torch.Tensor) -> torch.Tensor:
+        outs = [self.aligners[i](x) for i, x in enumerate(xs)]
+        return outs
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: float | torch.Tensor = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+
+class SwiGLU(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gates = x.chunk(2, dim=-1)
+        return x * F.silu(gates)
+
+
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        gated: bool = False,
+        output_dim: int | None = None,
+    ):
+        super().__init__()
+        if gated:
+            expansion = int(expansion * 2 / 3)
+        hidden_dim = int(input_dim * expansion)
+        output_dim = default(output_dim, input_dim)
+        self.norm = nn.LayerNorm(input_dim)
+        self.proj1 = nn.Linear(input_dim, hidden_dim)
+        self.proj2 = nn.Linear(hidden_dim, output_dim)
+        self.act = nn.GELU() if not gated else SwiGLU()
+        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.norm(x)
+        x = self.proj1(x)
+        x = self.act(x)
+        x = self.proj2(x)
+        x = self.dropout(x)
+        return x
+
+
+class AttentionBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 4,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        cosine: bool = False,
+        gated: bool = False,
+        layer_scale: float = 1.0,
+        context_dim: int | None = None,
+        detach_query: bool = False,
+        residual_ls: bool = False,
+    ):
+        super().__init__()
+        self.dropout = dropout
+        self.num_heads = num_heads
+        self.hidden_dim = dim
+        context_dim = dim if context_dim is None else context_dim
+        self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+        self.kv = nn.Linear(context_dim, dim * 2, bias=False)
+        self.q = nn.Linear(dim, dim, bias=False)
+        self.norm_attnx = nn.LayerNorm(dim)
+        self.norm_attnctx = nn.LayerNorm(context_dim)
+        self.cosine = cosine
+        self.out = nn.Linear(dim, dim, bias=False)
+        self.ls1_1 = (
+            LayerScale(dim, layer_scale)
+            if layer_scale > 0.0 and not residual_ls
+            else nn.Identity()
+        )
+        self.ls1_2 = (
+            LayerScale(dim, layer_scale)
+            if layer_scale > 0.0 and residual_ls
+            else nn.Identity()
+        )
+        self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+        self.detach_query = detach_query
+
+    def attn(
+        self,
+        x: torch.Tensor,
+        attn_bias: torch.Tensor | None = None,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+        rope_pos: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if self.detach_query:
+            x = x.detach()
+        x = self.norm_attnx(x)
+        context = self.norm_attnctx(context)
+        k, v = rearrange(
+            self.kv(context), 'b n (kv h d) -> b h n d kv', h=self.num_heads, kv=2
+        ).unbind(dim=-1)
+        q = rearrange(self.q(x), 'b n (h d) -> b h n d', h=self.num_heads)
+
+        if rope is not None:
+            q = rope(q.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+            k = rope(k.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+        else:
+            if pos_embed is not None:
+                pos_embed = rearrange(
+                    pos_embed, 'b n (h d) -> b h n d', h=self.num_heads
+                )
+                q = q + pos_embed
+            if pos_embed_context is not None:
+                pos_embed_context = rearrange(
+                    pos_embed_context, 'b n (h d) -> b h n d', h=self.num_heads
+                )
+                k = k + pos_embed_context
+
+        if self.cosine:
+            q, k = map(partial(F.normalize, p=2, dim=-1), (q, k))  # cosine sim
+
+        x = F.scaled_dot_product_attention(
+            q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+        )
+        x = rearrange(x, 'b h n d -> b n (h d)')
+        x = self.out(x)
+        return x
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        attn_bias: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+        rope_pos: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        context = x if context is None else context
+        x = self.ls1_1(
+            self.attn(
+                x,
+                rope=rope,
+                rope_pos=rope_pos,
+                attn_bias=attn_bias,
+                context=context,
+                pos_embed=pos_embed,
+                pos_embed_context=pos_embed_context,
+            )
+        ) + self.ls1_2(x)
+        x = self.ls2(self.mlp(x)) + x
+        return x
+
+
+class AttentionSeq(nn.Module):
+    def __init__(
+        self,
+        num_blocks: int,
+        dim: int,
+        num_heads: int = 4,
+        expansion: int = 4,
+        dropout: float = 0.0,
+        cosine: bool = False,
+        gated: bool = False,
+        layer_scale: float = 1.0,
+        context_dim: int | None = None,
+        detach_query: bool = False,
+        residual_ls: bool = False,
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList(
+            [
+                AttentionBlock(
+                    dim=dim,
+                    num_heads=num_heads,
+                    expansion=expansion,
+                    dropout=dropout,
+                    cosine=cosine,
+                    gated=gated,
+                    layer_scale=layer_scale,
+                    context_dim=context_dim,
+                    detach_query=detach_query,
+                    residual_ls=residual_ls,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor | None = None,
+        pos_embed: torch.Tensor | None = None,
+        pos_embed_context: torch.Tensor | None = None,
+        attn_bias: torch.Tensor | None = None,
+        rope: nn.Module | None = None,
+        rope_pos: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(
+                x,
+                context=context,
+                pos_embed=pos_embed,
+                pos_embed_context=pos_embed_context,
+                attn_bias=attn_bias,
+                rope=rope,
+                rope_pos=rope_pos,
+            )
+        return x
+
+
+class ResidualConvNet(nn.Module):
+    def __init__(
+        self,
+        dim,
+        kernel_size: int = 3,
+        padding_mode: str = 'zeros',
+        dilation: int = 1,
+        layer_scale: float = 1.0,
+        use_norm: bool = False,
+    ):
+        super().__init__()
+        self.conv1 = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=dilation * (kernel_size - 1) // 2,
+            dilation=dilation,
+            padding_mode=padding_mode,
+        )
+        self.conv2 = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=dilation * (kernel_size - 1) // 2,
+            dilation=dilation,
+            padding_mode=padding_mode,
+        )
+        self.activation = nn.LeakyReLU()
+        self.gamma = (
+            nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1))
+            if layer_scale > 0.0
+            else 1.0
+        )
+        self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+        self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+
+    def forward(self, x):
+        out = self.activation(x)
+        out = self.conv1(out)
+        out = self.norm1(out)
+        out = self.activation(out)
+        out = self.conv2(out)
+        out = self.norm2(out)
+        return self.gamma * out + x
+
+
+class ResidualUpsampler(nn.Module):
+    def __init__(
+        self,
+        hidden_dim,
+        output_dim: int = None,
+        num_layers: int = 2,
+        kernel_size: int = 3,
+        layer_scale: float = 1.0,
+        padding_mode: str = 'zeros',
+        use_norm: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        output_dim = output_dim if output_dim is not None else hidden_dim // 2
+        self.convs = nn.ModuleList([])
+        for _ in range(num_layers):
+            self.convs.append(
+                ResidualConvNet(
+                    hidden_dim,
+                    kernel_size=kernel_size,
+                    layer_scale=layer_scale,
+                    padding_mode=padding_mode,
+                    use_norm=use_norm,
+                )
+            )
+        self.up = nn.Sequential(
+            nn.Conv2d(
+                hidden_dim,
+                output_dim,
+                kernel_size=1,
+                padding=0,
+                padding_mode=padding_mode,
+            ),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+        )
+
+    def forward(self, x: torch.Tensor):
+        for conv in self.convs:
+            x = conv(x)
+        x = self.up(x)
+        return x
da2/model/dinov2/__init__.py
ADDED
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dinovit import (
+    DINOViT
+)
+
+__all__ = [
+    'DINOViT'
+]
da2/model/dinov2/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (237 Bytes)

da2/model/dinov2/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (4.13 kB)

da2/model/dinov2/__pycache__/block.cpython-312.pyc
ADDED
Binary file (13.5 kB)

da2/model/dinov2/__pycache__/dinovit.cpython-312.pyc
ADDED
Binary file (9.6 kB)

da2/model/dinov2/__pycache__/drop_path.cpython-312.pyc
ADDED
Binary file (1.63 kB)

da2/model/dinov2/__pycache__/layer_scale.cpython-312.pyc
ADDED
Binary file (1.39 kB)

da2/model/dinov2/__pycache__/mlp.cpython-312.pyc
ADDED
Binary file (1.92 kB)

da2/model/dinov2/__pycache__/patch_embed.cpython-312.pyc
ADDED
Binary file (4.07 kB)

da2/model/dinov2/__pycache__/swiglu_ffn.cpython-312.pyc
ADDED
Binary file (2.81 kB)
da2/model/dinov2/attention.py
ADDED
@@ -0,0 +1,79 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import fmha, memory_efficient_attention, unbind
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity()
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        x = F.scaled_dot_product_attention(qkv[0], qkv[1], qkv[2])
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class MemEffAttention(Attention):
+    def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+        if not XFORMERS_AVAILABLE or x.device.type == "cpu":
+            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+        q, k, v = unbind(qkv, 2)
+
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = x.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
da2/model/dinov2/block.py
ADDED
@@ -0,0 +1,280 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+logger = logging.getLogger("dinov2")
+
+try:
+    from xformers.ops import fmha, index_select_cat, scaled_index_add
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        ffn_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values=None,
+        drop_path: float = 0.0,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_class: Callable[..., nn.Module] = Attention,
+        ffn_layer: Callable[..., nn.Module] = Mlp,
+    ) -> None:
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_class(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            proj_bias=proj_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.ls1 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ffn_layer(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+            bias=ffn_bias,
+        )
+        self.ls2 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.sample_drop_ratio = drop_path
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
+            return self.ls1(self.attn(self.norm1(x)))
+
+        def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
+            return self.ls2(self.mlp(self.norm2(x)))
+
+        if self.training and self.sample_drop_ratio > 0.1:
+            # the overhead is compensated only for a drop path rate larger than 0.1
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+        elif self.training and self.sample_drop_ratio > 0.0:
+            x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+        else:
+            x = x + attn_residual_func(x)
+            x = x + ffn_residual_func(x)
+        return x
+
+
+def drop_add_residual_stochastic_depth(
+    x: torch.Tensor,
+    residual_func,  #: Callable[[torch.Tensor], torch.Tensor],
+    sample_drop_ratio: float = 0.0,
+) -> torch.Tensor:
+    # 1) extract subset using permutation
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    x_subset = x[brange]
+
+    # 2) apply residual_func to get residual
+    residual = residual_func(x_subset)
+
+    x_flat = x.flatten(1)
+    residual = residual.flatten(1)
+
+    residual_scale_factor = b / sample_subset_size
+
+    # 3) add the residual
+    x_plus_residual = torch.index_add(
+        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+    )
+    return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    residual_scale_factor = b / sample_subset_size
+    return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+    if scaling_vector is None:
+        x_flat = x.flatten(1)
+        residual = residual.flatten(1)
+        x_plus_residual = torch.index_add(
+            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+        )
+    else:
+        x_plus_residual = scaled_index_add(
+            x,
+            brange,
+            residual.to(dtype=x.dtype),
+            scaling=scaling_vector,
+            alpha=residual_scale_factor,
+        )
+    return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+    """
+    this will perform the index select, cat the tensors, and provide the attn_bias from cache
+    """
+    batch_sizes = (
+        [b.shape[0] for b in branges]
+        if branges is not None
+        else [x.shape[0] for x in x_list]
+    )
+    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+    if all_shapes not in attn_bias_cache.keys():
+        seqlens = []
+        for b, x in zip(batch_sizes, x_list):
+            for _ in range(b):
+                seqlens.append(x.shape[1])
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        attn_bias._batch_sizes = batch_sizes
+        attn_bias_cache[all_shapes] = attn_bias
+
+    if branges is not None:
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+            1, -1, x_list[0].shape[-1]
+        )
+    else:
+        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+        cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+    return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+    x_list: List[torch.Tensor],
+    residual_func,  #: Callable[[torch.Tensor, Any], torch.Tensor],
+    sample_drop_ratio: float = 0.0,
+    scaling_vector=None,
+) -> torch.Tensor:
+    # 1) generate random set of indices for dropping samples in the batch
+    branges_scales = [
+        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+    ]
+    branges = [s[0] for s in branges_scales]
+    residual_scale_factors = [s[1] for s in branges_scales]
+
+    # 2) get attention bias and index+concat the tensors
+    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+    # 3) apply residual_func to get residual, and split the result
+    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+    outputs = []
+    for x, brange, residual, residual_scale_factor in zip(
+        x_list, branges, residual_list, residual_scale_factors
+    ):
+        outputs.append(
+            add_residual(
+                x, brange, residual, residual_scale_factor, scaling_vector
+            ).view_as(x)
+        )
+    return outputs
+
+
+class NestedTensorBlock(Block):
+    def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+        """
+        x_list contains a list of tensors to nest together and run
+        """
+        assert isinstance(self.attn, MemEffAttention)
+
+        if self.training and self.sample_drop_ratio > 0.0:
+
+            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.mlp(self.norm2(x))
+
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=(
+                    self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
+                ),
+            )
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=(
+                    self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
+                ),
+            )
+            return x_list
+        else:
+
+            def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+            def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+                return self.ls2(self.mlp(self.norm2(x)))
+
+            attn_bias, x = get_attn_bias_and_cat(x_list)
+            x = x + attn_residual_func(x, attn_bias=attn_bias)
+            x = x + ffn_residual_func(x)
+            return attn_bias.split(x)
+
+    def forward(self, x_or_x_list):
+        if isinstance(x_or_x_list, torch.Tensor):
+            return super().forward(x_or_x_list)
+        elif isinstance(x_or_x_list, list):
+            assert (
+                XFORMERS_AVAILABLE
+            ), "Please install xFormers for nested tensors usage"
+            return self.forward_nested(x_or_x_list)
+        else:
+            raise AssertionError
da2/model/dinov2/dino_head.py
ADDED
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        use_bn=False,
+        nlayers=3,
+        hidden_dim=2048,
+        bottleneck_dim=256,
+        mlp_bias=True,
+    ):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        self.mlp = _build_mlp(
+            nlayers,
+            in_dim,
+            bottleneck_dim,
+            hidden_dim=hidden_dim,
+            use_bn=use_bn,
+            bias=mlp_bias,
+        )
+        self.apply(self._init_weights)
+        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+        x = self.last_layer(x)
+        return x
+
+
+def _build_mlp(
+    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+):
+    if nlayers == 1:
+        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+    else:
+        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+        if use_bn:
+            layers.append(nn.BatchNorm1d(hidden_dim))
+        layers.append(nn.GELU())
+        for _ in range(nlayers - 2):
+            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+        return nn.Sequential(*layers)
da2/model/dinov2/dinovit.py
ADDED
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+import math
+import torch
+import torch.nn as nn
+import contextlib
+from functools import partial
+from typing import Sequence
+from .block import (
+    Block
+)
+from .attention import (
+    MemEffAttention
+)
+from .mlp import (
+    Mlp
+)
+from .patch_embed import (
+    PatchEmbed
+)
+from .swiglu_ffn import (
+    SwiGLUFFNFused
+)
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import fmha, memory_efficient_attention, unbind
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class DINOViT(nn.Module):
+    def __init__(
+        self,
+        img_size=518,
+        patch_size=14,
+        in_chans=3,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        qkv_bias=True,
+        ffn_bias=True,
+        proj_bias=True,
+        drop_path_rate=0.0,
+        drop_path_uniform=False,
+        init_values=1.0,
+        embed_layer=PatchEmbed,
+        act_layer=nn.GELU,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        ffn_layer="mlp",
+        block_chunks=0,
+        output_idx=[6, 12, 18, 24],
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        use_norm=True,
+        frozen_stages=0,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            proj_bias (bool): enable bias for proj in attn if True
+            ffn_bias (bool): enable bias for ffn if True
+            drop_path_rate (float): stochastic depth rate
+            drop_path_uniform (bool): apply uniform drop rate across blocks
+            weight_init (str): weight init scheme
+            init_values (float): layer-scale init values
+            embed_layer (nn.Module): patch embedding layer
+            act_layer (nn.Module): MLP activation layer
+            block_fn (nn.Module): transformer block class
+            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+        """
+        super().__init__()
+        self.frozen_stages = frozen_stages
+        self.patch_size = patch_size
+        self.output_idx = output_idx
+        self.num_register_tokens = num_register_tokens
+        self.interpolate_antialias = interpolate_antialias
+
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + 1, embed_dim)
+        )
+        assert num_register_tokens >= 0
+        self.register_tokens = nn.Parameter(
+            torch.zeros(1, max(1, num_register_tokens), embed_dim)
+        )
+
+        if drop_path_uniform is True:
+            dpr = [drop_path_rate] * depth
+        else:
+            dpr = [
+                x.item() for x in torch.linspace(0, drop_path_rate, depth)
+            ]
+
+        if ffn_layer == "mlp":
+            ffn_layer = Mlp
+        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+            ffn_layer = SwiGLUFFNFused
+        elif ffn_layer == "identity":
+            def f():
+                return nn.Identity()
+            ffn_layer = f
+        else:
+            raise NotImplementedError
+
+        blocks_list = [
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                proj_bias=proj_bias,
+                ffn_bias=ffn_bias,
+                drop_path=dpr[i],
+                norm_layer=nn.LayerNorm,
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
+                init_values=init_values,
+            )
+            for i in range(depth)
+        ]
+        self.chunked_blocks = False
+        self.blocks = nn.ModuleList(blocks_list)
+
+        self.norm = nn.LayerNorm(embed_dim)
+        self.use_norm = use_norm
+        self.head = nn.Identity()
+        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+    def interpolate_pos_encoding(self, x, W, H):
+        previous_dtype = x.dtype
+        N = self.pos_embed.shape[1] - 1
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = W // self.patch_size
+        h0 = H // self.patch_size
+
+        M = int(math.sqrt(N))
+        assert N == M * M
+        kwargs = {}
+        kwargs["size"] = (w0, h0)
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+            mode="bicubic",
+            antialias=self.interpolate_antialias,
+            **kwargs,
+        )
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
+
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+            previous_dtype
+        )
+
+    def tokenize(self, x):
+        _, _, W, H = x.shape
+        with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext():
+            x = self.patch_embed(x)
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        dino_pos_embed = self.interpolate_pos_encoding(x, W, H)
+        x = x + dino_pos_embed
+        return x
+
+    def forward_features(self, x):
+        shapes = [val // self.patch_size for val in x.shape[-2:]]
+        batch_size = x.shape[0]
+        features = []
+        x = self.tokenize(x)
+        for i, blk in enumerate(self.blocks):
+            with (
+                torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext()
+            ):
+                x = blk(x)
+            features.append(x)
+        if self.use_norm:
+            with (
+                torch.no_grad()
+                if self.frozen_stages >= len(self.blocks)
+                else contextlib.nullcontext()
+            ):
+                features = [self.norm(out) for out in features]
+        features = [out[:, self.num_register_tokens + 1 :] for out in features]
+        features = [out.reshape(batch_size, *shapes, -1) for out in features]
+        return features
+
+    def forward(self, *args):
+        features = self.forward_features(*args)
+        return features
da2/model/dinov2/drop_path.py
ADDED
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+import torch.nn as nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
da2/model/dinov2/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
da2/model/dinov2/mlp.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
da2/model/dinov2/patch_embed.py
ADDED
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+
+
+def make_2tuple(x):
+    if isinstance(x, tuple):
+        assert len(x) == 2
+        return x
+
+    assert isinstance(x, int)
+    return (x, x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+    Args:
+        img_size: Image size.
+        patch_size: Patch token size.
+        in_chans: Number of input image channels.
+        embed_dim: Number of linear projection output channels.
+        norm_layer: Normalization layer.
+    """
+
+    def __init__(
+        self,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: Optional[Callable] = None,
+        flatten_embedding: bool = True,
+    ) -> None:
+        super().__init__()
+
+        image_HW = make_2tuple(img_size)
+        patch_HW = make_2tuple(patch_size)
+        patch_grid_size = (
+            image_HW[0] // patch_HW[0],
+            image_HW[1] // patch_HW[1],
+        )
+
+        self.img_size = image_HW
+        self.patch_size = patch_HW
+        self.patches_resolution = patch_grid_size
+        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.flatten_embedding = flatten_embedding
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+        )
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+
+        assert (
+            H % patch_H == 0
+        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert (
+            W % patch_W == 0
+        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        x = self.norm(x)
+        if not self.flatten_embedding:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+    def flops(self) -> float:
+        Ho, Wo = self.patches_resolution
+        flops = (
+            Ho
+            * Wo
+            * self.embed_dim
+            * self.in_chans
+            * (self.patch_size[0] * self.patch_size[1])
+        )
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
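
A shape sketch for PatchEmbed above (assuming the da2 package is importable; the sizes are illustrative):

# A 224x224 RGB image with 16x16 patches becomes 14 * 14 = 196 tokens of width 768.
import torch
from da2.model.dinov2.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(1, 3, 224, 224)
tokens = embed(x)                     # flatten_embedding=True -> (B, N, D)
assert tokens.shape == (1, 196, 768)
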
da2/model/dinov2/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+try:
+    from xformers.ops import SwiGLU
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    SwiGLU = SwiGLUFFN
+    XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(
+            in_features=in_features,
+            hidden_features=hidden_features,
+            out_features=out_features,
+            bias=bias,
+        )
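
A worked example of the hidden-width rule in SwiGLUFFNFused above (plain arithmetic, no external dependencies):

# A requested hidden width of 4 * 768 = 3072 is shrunk to two thirds and rounded
# up to a multiple of 8, matching the expression in __init__.
hidden_features = 3072
fused_hidden = (int(hidden_features * 2 / 3) + 7) // 8 * 8
print(fused_hidden)  # 2048
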
da2/model/sphere.py
ADDED
@@ -0,0 +1,30 @@
+import torch
+
+
+def get_uv_grid(h, w, device):
+    pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device)
+    pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device)
+    stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]
+    grid = torch.stack(stacks, dim=0).float()
+    grid = grid.to(device).unsqueeze(0)
+    return grid
+
+
+class Sphere():
+    def __init__(self, config, device):
+        self.config = config
+        self.device = device
+
+    def get_directions(self, shape):
+        h, w = shape
+        uv = get_uv_grid(h, w, device=self.device)
+        u, v = uv.unbind(dim=1)
+        width, height = self.config['width'], self.config['height']
+        hfov, vfov = self.config['hfov'], self.config['vfov']
+        longitude = (u - width / 2) / width * hfov
+        latitude = (v - height / 2) / height * vfov
+        x = torch.cos(latitude) * torch.sin(longitude)
+        z = torch.cos(latitude) * torch.cos(longitude)
+        y = torch.sin(latitude)
+        sphere_directions = torch.stack([x, y, z], dim=1)
+        return sphere_directions
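
A minimal sketch of how Sphere is used (the hfov/vfov values below are illustrative and given in radians, matching the cos/sin usage above; the real values come from config['spherevit']['sphere']):

from math import pi
from da2.model.sphere import Sphere

cfg = {'width': 64, 'height': 32, 'hfov': 2 * pi, 'vfov': pi}  # full panorama
sphere = Sphere(config=cfg, device='cpu')
dirs = sphere.get_directions(shape=(32, 64))  # (1, 3, 32, 64) direction field
print(dirs.shape, dirs.norm(dim=1).mean())    # per-pixel norms are ~1 (unit directions)
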
da2/model/spherevit.py
ADDED
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+from math import (
+    ceil,
+    sqrt
+)
+from huggingface_hub import PyTorchModelHubMixin
+import torchvision.transforms.v2.functional as TF
+from .dinov2 import DINOViT
+from .vit_w_esphere import ViT_w_Esphere
+from .sphere import Sphere
+
+
+IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
+
+class SphereViT(nn.Module, PyTorchModelHubMixin):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.dino = DINOViT()
+        self.vit_w_esphere = ViT_w_Esphere(config['spherevit']['vit_w_esphere'])
+        feature_slices = self.dino.output_idx
+        self.feature_slices = list(
+            zip([0, *feature_slices[:-1]], feature_slices)
+        )
+        self.device = None
+
+    def to(self, *args):
+        self.device = args[0]
+        return super().to(*args)
+
+    def forward(self, images):
+        B, _, H, W = images.shape
+        current_pixels = H * W
+        target_pixels = min(self.config['inference']['max_pixels'],
+                            max(self.config['inference']['min_pixels'], current_pixels))
+        factor = sqrt(target_pixels / current_pixels)
+        sphere_config = deepcopy(self.config['spherevit']['sphere'])
+        sphere_config['width'] *= factor
+        sphere_config['height'] *= factor
+        sphere = Sphere(config=sphere_config, device=self.device)
+        H_new = int(H * factor)
+        W_new = int(W * factor)
+        DINO_patch_size = 14  # please see the line 51 of `src/da2/model/dinov2/dinovit.py` (I know it's a little ugly to hardcode it here T_T)
+        H_new = ceil(H_new / DINO_patch_size) * DINO_patch_size
+        W_new = ceil(W_new / DINO_patch_size) * DINO_patch_size
+        images = F.interpolate(images, size=(H_new, W_new), mode='bilinear', align_corners=False)
+        images = TF.normalize(
+            images.float(),
+            mean=IMAGENET_DATASET_MEAN,
+            std=IMAGENET_DATASET_STD,
+        )
+
+        sphere_dirs = sphere.get_directions(shape=(H_new, W_new))
+        sphere_dirs = sphere_dirs.to(self.device)
+        sphere_dirs = sphere_dirs.repeat(B, 1, 1, 1)
+
+        features = self.dino(images)
+        features = [
+            features[i:j][-1].contiguous()
+            for i, j in self.feature_slices
+        ]
+        distance = self.vit_w_esphere(images, features, sphere_dirs)
+        distance = F.interpolate(distance, size=(H, W), mode='bilinear', align_corners=False)
+        distance = distance.squeeze(dim=1)  # (b, 1, h, w) -> (b, h, w)
+        return distance
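
A worked example of the resizing step in forward above (the pixel budget is illustrative): the input is rescaled toward the budget and then snapped up to multiples of the DINOv2 patch size of 14.

from math import ceil, sqrt

H, W, patch = 1000, 1500, 14
target_pixels = 1_050_000                      # illustrative min/max-pixel budget
factor = sqrt(target_pixels / (H * W))         # ~0.8367
H_new = ceil(int(H * factor) / patch) * patch  # 836 -> 840
W_new = ceil(int(W * factor) / patch) * patch  # 1254 -> 1260
print(H_new, W_new)                            # 840 1260
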
da2/model/vit_w_esphere.py
ADDED
@@ -0,0 +1,224 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from .base import (
+    fourier_dimension_expansion,
+    flatten,
+    DimensionAligner,
+    AttentionSeq,
+    ResidualUpsampler
+)
+
+
+class _ViT_w_Esphere(nn.Module):
+    def __init__(
+        self,
+        hidden_dim: int,
+        num_heads: int = 8,
+        expansion: int = 4,
+        num_layers_head: int | list[int] = 4,
+        dropout: float = 0.0,
+        kernel_size: int = 7,
+        layer_scale: float = 1.0,
+        out_dim: int = 1,
+        num_prompt_blocks: int = 1,
+        use_norm: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.out_dim = out_dim
+        self.hidden_dim = hidden_dim
+        self.up_sampler = nn.ModuleList([])
+        self.pred_head = nn.ModuleList([])
+        self.process_features = nn.ModuleList([])
+        self.prompt_camera = nn.ModuleList([])
+        mult = 2
+        self.to_latents = nn.Linear(hidden_dim, hidden_dim)
+
+        for _ in range(4):
+            self.prompt_camera.append(
+                AttentionSeq(
+                    num_blocks=num_prompt_blocks,
+                    dim=hidden_dim,
+                    num_heads=num_heads,
+                    expansion=expansion,
+                    dropout=dropout,
+                    layer_scale=-1.0,
+                    context_dim=hidden_dim,
+                )
+            )
+
+        for i, depth in enumerate(num_layers_head):
+            current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
+            next_dim = mult * hidden_dim // int(2 ** (i + 1))
+            output_dim = max(next_dim, out_dim)
+            self.process_features.append(
+                nn.ConvTranspose2d(
+                    hidden_dim,
+                    current_dim,
+                    kernel_size=max(1, 2 * i),
+                    stride=max(1, 2 * i),
+                    padding=0,
+                )
+            )
+            self.up_sampler.append(
+                ResidualUpsampler(
+                    current_dim,
+                    output_dim=output_dim,
+                    expansion=expansion,
+                    layer_scale=layer_scale,
+                    kernel_size=kernel_size,
+                    num_layers=depth,
+                    use_norm=use_norm,
+                )
+            )
+            pred_head = (
+                nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
+                if i == len(num_layers_head) - 1
+                else nn.Identity()
+            )
+            self.pred_head.append(pred_head)
+
+        self.to_depth_lr = nn.Conv2d(
+            output_dim,
+            output_dim // 2,
+            kernel_size=3,
+            padding=1,
+            padding_mode='reflect',
+        )
+        self.to_confidence_lr = nn.Conv2d(
+            output_dim,
+            output_dim // 2,
+            kernel_size=3,
+            padding=1,
+            padding_mode='reflect',
+        )
+        self.to_depth_hr = nn.Sequential(
+            nn.Conv2d(
+                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
+            ),
+            nn.LeakyReLU(),
+            nn.Conv2d(32, 1, kernel_size=1),
+        )
+        self.to_confidence_hr = nn.Sequential(
+            nn.Conv2d(
+                output_dim // 2, 32, kernel_size=3, padding=1, padding_mode='reflect'
+            ),
+            nn.LeakyReLU(),
+            nn.Conv2d(32, 1, kernel_size=1),
+        )
+
+    def set_original_shapes(self, shapes: tuple[int, int]):
+        self.original_shapes = shapes
+
+    def set_shapes(self, shapes: tuple[int, int]):
+        self.shapes = shapes
+
+    def embed_sphere_dirs(self, sphere_dirs):
+        sphere_embedding = flatten(
+            sphere_dirs, old=self.original_shapes, new=self.shapes
+        )
+        # index 0 -> Y
+        # index 1 -> Z
+        # index 2 -> X
+        r1, r2, r3 = sphere_embedding[..., 0], sphere_embedding[..., 1], sphere_embedding[..., 2]
+        polar = torch.asin(r2)
+        r3_clipped = r3.abs().clip(min=1e-5) * (2 * (r3 >= 0).int() - 1)
+        azimuth = torch.atan2(r1, r3_clipped)
+        # [polar, azimuth] is the angle field
+        sphere_embedding = torch.stack([polar, azimuth], dim=-1)
+        # expand the dimension of the angle field to image feature dimensions, via sine-cosine basis embedding
+        sphere_embedding = fourier_dimension_expansion(
+            sphere_embedding,
+            dim=self.hidden_dim,
+            max_freq=max(self.shapes) // 2,
+            use_cos=False,
+        )
+        return sphere_embedding
+
+    def condition(self, feat, sphere_embeddings):
+        conditioned_features = [
+            prompter(rearrange(feature, 'b h w c -> b (h w) c'), sphere_embeddings)
+            for prompter, feature in zip(self.prompt_camera, feat)
+        ]
+        return conditioned_features
+
+    def process(self, features_list, sphere_embeddings):
+        conditioned_features = self.condition(features_list, sphere_embeddings)
+        init_latents = self.to_latents(conditioned_features[0])
+        init_latents = rearrange(
+            init_latents, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
+        ).contiguous()
+        conditioned_features = [
+            rearrange(
+                x, 'b (h w) c -> b c h w', h=self.shapes[0], w=self.shapes[1]
+            ).contiguous()
+            for x in conditioned_features
+        ]
+        latents = init_latents
+
+        out_features = []
+        # Pyramid-like multi-layer convolutional feature extraction
+        for i, up in enumerate(self.up_sampler):
+            latents = latents + self.process_features[i](conditioned_features[i + 1])
+            latents = up(latents)
+            out_features.append(latents)
+        return out_features
+
+    def prediction_head(self, out_features):
+        depths = []
+        h_out, w_out = out_features[-1].shape[-2:]
+        for i, (layer, features) in enumerate(zip(self.pred_head, out_features)):
+            out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+            if i < len(self.pred_head) - 1:
+                continue
+            depths.append(out_depth_features)
+            out_depth_features = F.interpolate(
+                out_depth_features, size=(h_out, w_out), mode='bilinear', align_corners=True
+            )
+            distance = self.to_depth_lr(out_depth_features)
+            distance = F.interpolate(
+                distance, size=self.original_shapes, mode='bilinear', align_corners=True
+            )
+            distance = self.to_depth_hr(distance)
+        return distance
+
+    def forward(
+        self,
+        features: list[torch.Tensor],
+        sphere_dirs: torch.Tensor
+    ) -> torch.Tensor:
+        sphere_embeddings = self.embed_sphere_dirs(sphere_dirs)
+        features = self.process(features, sphere_embeddings)
+        distance = self.prediction_head(features)
+        return distance
+
+
+class ViT_w_Esphere(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.dim_aligner = DimensionAligner(
+            input_dims=config['input_dims'],
+            hidden_dim=config['hidden_dim'],
+        )
+        self._vit_w_esphere = _ViT_w_Esphere(**config)
+
+    def forward(self, images, features, sphere_dirs) -> torch.Tensor:
+        _, _, H, W = images.shape
+        sphere_dirs = sphere_dirs
+        common_shape = features[0].shape[1:3]
+        features = self.dim_aligner(features)
+        sphere_dirs = rearrange(sphere_dirs, 'b c h w -> b (h w) c')
+
+        self._vit_w_esphere.set_shapes(common_shape)
+        self._vit_w_esphere.set_original_shapes((H, W))
+        logdistance = self._vit_w_esphere(
+            features=features,
+            sphere_dirs=sphere_dirs,
+        )
+
+        distance = torch.exp(logdistance.clip(min=-8.0, max=8.0) + 2.0)
+        distance = distance / torch.quantile(distance, 0.98)
+        return distance
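
A small numeric illustration of the angle-field step in embed_sphere_dirs above: each (roughly unit) direction is reduced to (polar, azimuth) before the sine-cosine Fourier expansion; the clipping in the original only guards atan2 near zero, which this example avoids.

import torch

d = torch.tensor([0.5, 0.5, 0.7071])  # illustrative direction, approximately unit norm
r1, r2, r3 = d[0], d[1], d[2]
polar = torch.asin(r2)                 # ~0.5236 rad (30 degrees)
azimuth = torch.atan2(r1, r3)          # ~0.6155 rad
print(polar.item(), azimuth.item())
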
da2/utils/__init__.py
ADDED
@@ -0,0 +1,11 @@
+from .base import (
+    prepare_to_run
+)
+from .model import (
+    load_model
+)
+
+__all__ = [
+    'prepare_to_run',
+    'load_model'
+]
da2/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (287 Bytes). View file

da2/utils/__pycache__/base.cpython-312.pyc
ADDED
Binary file (3.22 kB). View file

da2/utils/__pycache__/d2pc.cpython-312.pyc
ADDED
Binary file (6.21 kB). View file

da2/utils/__pycache__/io.cpython-312.pyc
ADDED
Binary file (3.62 kB). View file

da2/utils/__pycache__/model.cpython-312.pyc
ADDED
Binary file (1.23 kB). View file

da2/utils/__pycache__/vis.cpython-312.pyc
ADDED
Binary file (2.42 kB). View file

da2/utils/base.py
ADDED
@@ -0,0 +1,56 @@
+import json
+import argparse
+import os
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import (
+    InitProcessGroupKwargs,
+    ProjectConfiguration,
+    set_seed
+)
+import logging
+from datetime import (
+    timedelta,
+    datetime
+)
+
+
+def load_config(config_path):
+    with open(config_path, 'r') as f:
+        config = json.load(f)
+    return config
+
+def arg_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config_path', type=str, required=True)
+    args = parser.parse_args()
+    return args
+
+def prepare_to_run():
+    args = arg_parser()
+    logging.basicConfig(
+        format='%(asctime)s --> %(message)s',
+        datefmt='%m/%d %H:%M:%S',
+        level=logging.INFO,
+    )
+    config = load_config(args.config_path)
+    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=config['accelerator']['timeout']))
+    version = os.path.basename(args.config_path)[:-5]
+    output_dir = f'output/{version}_{datetime.now().strftime("%Y%m%d_%H%M%S")}'
+    if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
+    accu_steps = config['accelerator']['accumulation_nsteps']
+    accelerator = Accelerator(
+        gradient_accumulation_steps=accu_steps,
+        mixed_precision=config['accelerator']['mixed_precision'],
+        log_with=config['accelerator']['report_to'],
+        project_config=ProjectConfiguration(project_dir=output_dir),
+        kwargs_handlers=[kwargs]
+    )
+    logger = get_logger(__name__, log_level='INFO')
+    config['env']['logger'] = logger
+    set_seed(config['env']['seed'])
+    if config['env']['verbose']:
+        logger.info(f'Version: {version} (from {args.config_path})')
+        logger.info(f'Output dir: {output_dir}')
+        logger.info(f'Using {accelerator.num_processes} GPU' + ('s' if accelerator.num_processes > 1 else ''))
+    return config, accelerator, output_dir
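
A sketch of the minimal in-memory config that prepare_to_run above expects after json.load (the key names are the ones read by the code in this commit; the values here are placeholders):

config_example = {
    'accelerator': {
        'timeout': 3600,
        'accumulation_nsteps': 1,
        'mixed_precision': 'no',
        'report_to': 'tensorboard',
    },
    'env': {'seed': 42, 'verbose': True},
    'inference': {
        'images': 'assets/images',   # read by load_infer_data (da2/utils/io.py)
        'masks': 'assets/masks',
        'min_pixels': 1,             # read by SphereViT.forward
        'max_pixels': 1_000_000,
    },
}
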
da2/utils/d2pc.py
ADDED
@@ -0,0 +1,76 @@
+import os
+import numpy as np
+import utils3d
+from plyfile import PlyData, PlyElement
+from PIL import Image
+
+
+def sphere_uv2dirs(uv: np.ndarray):
+    theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
+    directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
+    return directions
+
+def save_3d_points(points: np.array, colors: np.array, mask: np.array, save_path: str):
+    points = points.reshape(-1, 3)
+    colors = colors.reshape(-1, 3)
+    mask = mask.reshape(-1, 1)
+
+    vertex_data = np.empty(mask.sum(), dtype=[
+        ('x', 'f4'),
+        ('y', 'f4'),
+        ('z', 'f4'),
+        ('red', 'u1'),
+        ('green', 'u1'),
+        ('blue', 'u1'),
+    ])
+    vertex_data['x'] = [a for i, a in enumerate(points[:, 0]) if mask[i]]
+    vertex_data['y'] = [a for i, a in enumerate(points[:, 1]) if mask[i]]
+    vertex_data['z'] = [a for i, a in enumerate(points[:, 2]) if mask[i]]
+    vertex_data['red'] = [a for i, a in enumerate(colors[:, 0]) if mask[i]]
+    vertex_data['green'] = [a for i, a in enumerate(colors[:, 1]) if mask[i]]
+    vertex_data['blue'] = [a for i, a in enumerate(colors[:, 2]) if mask[i]]
+
+    vertex_element = PlyElement.describe(vertex_data, 'vertex', comments=['vertices with color'])
+    save_dir = os.path.dirname(save_path)
+    if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True)
+    PlyData([vertex_element], text=True).write(save_path)
+
+def colorize_normal(normal: np.ndarray, normal_mask: np.ndarray):
+    normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8)
+    normal_mask = np.repeat(np.expand_dims(normal_mask, axis=-1), 3, axis=-1)
+    normal_mask = normal_mask.astype(np.uint8)
+    normal_rgb = normal_rgb * normal_mask
+    return normal_rgb
+
+def normal_normalize(normal: np.ndarray):
+    normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
+    normal_norm[normal_norm < 1e-6] = 1e-6
+    return normal / normal_norm
+
+def distance2pointcloud(
+    distance: np.ndarray,
+    image: np.ndarray,
+    mask: np.ndarray,
+    save_path: str = None,
+    return_normal: bool = False,
+    save_distance: bool = False
+):
+    if distance.ndim >= 3: distance = distance.squeeze()
+    if save_distance:
+        save_path_dis = save_path.replace('3dpc', 'depth').replace('.ply', '.npy')
+        save_dir_dis = os.path.dirname(save_path_dis)
+        if not os.path.exists(save_dir_dis): os.makedirs(save_dir_dis, exist_ok=True)
+        np.save(save_path_dis, distance)
+    height, width = distance.shape[:2]
+    points = distance[:, :, None] * sphere_uv2dirs(utils3d.numpy.image_uv(width=width, height=height))
+    save_3d_points(points, image, mask, save_path)
+    if return_normal:
+        normal, normal_mask = utils3d.numpy.points_to_normals(points, mask)
+        normal = normal * np.array([-1, -1, 1])
+        normal = normal_normalize(normal)
+        normal_1 = normal[..., 0]
+        normal_2 = normal[..., 1]
+        normal_3 = normal[..., 2]
+        normal = np.stack([normal_1, normal_3, normal_2], axis=-1)
+        normal_img = colorize_normal(normal, normal_mask)
+        return Image.fromarray(normal_img)
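
A quick check of sphere_uv2dirs above, written inline so it runs without utils3d: uv = (0.25, 0.5) lies on the equator (phi = pi/2) at theta = 1.5*pi, so it maps to roughly the unit direction (0, -1, 0).

import numpy as np

uv = np.array([[0.25, 0.5]])
theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
dirs = np.stack([np.sin(phi) * np.cos(theta),
                 np.sin(phi) * np.sin(theta),
                 np.cos(phi)], axis=-1)
print(np.round(dirs, 6))  # approximately [[0. -1. 0.]]
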
da2/utils/io.py
ADDED
@@ -0,0 +1,63 @@
+import os
+import cv2
+import torch
+import numpy as np
+from glob import glob
+from PIL import Image
+
+
+def torch_transform(image):
+    image = image / 255.0
+    image = np.transpose(image, (2, 0, 1))
+    return image
+
+def read_cv2_image(image_path):
+    image = cv2.imread(image_path)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    return image
+
+def read_mask(mask_path, shape):
+    if not os.path.exists(mask_path):
+        return np.ones(shape[1:]) > 0
+    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+    mask = mask > 0
+    return mask
+
+def tensorize(array, model_dtype, device):
+    array = torch.from_numpy(array).to(device).to(model_dtype).unsqueeze(dim=0)
+    return array
+
+def load_infer_data(config, device):
+    image_dir = config['inference']['images']
+    mask_dir = config['inference']['masks']
+
+    image_paths = glob(os.path.join(image_dir, '*.png'))
+    image_paths = sorted(image_paths)
+    filenames = [os.path.basename(image_path)[:-4] for image_path in image_paths]
+    cv2_images = [read_cv2_image(image_path)
+                  for image_path in image_paths]
+    PIL_images = [Image.fromarray(cv2_image) for cv2_image in cv2_images]
+    images = [torch_transform(cv2_image) for cv2_image in cv2_images]
+
+    mask_paths = [image_path.replace(image_dir, mask_dir)
+                  for image_path in image_paths]
+    masks = [read_mask(mask_path, images[i].shape)
+             for (i, mask_path) in enumerate(mask_paths)]
+
+    model_dtype = config['spherevit']['dtype']
+    images = [tensorize(image, model_dtype, device) for image in images]
+
+    infer_data = {
+        'images': {
+            'PIL': PIL_images,
+            'cv2': cv2_images,
+            'torch': images
+        },
+        'masks': masks,
+        'filenames': filenames,
+        'size': len(images)
+    }
+    if config['env']['verbose']:
+        s = 's' if len(images) > 1 else ''
+        config['env']['logger'].info(f'Loaded {len(images)} image{s} in {model_dtype}')
+    return infer_data
da2/utils/model.py
ADDED
@@ -0,0 +1,15 @@
+import torch
+from ..model.spherevit import SphereViT
+
+
+def load_model(config, accelerator):
+    model = SphereViT.from_pretrained('haodongli/DA-2', config=config)
+    model = model.to(accelerator.device)
+    torch.cuda.empty_cache()
+    model = accelerator.prepare(model)
+    if accelerator.num_processes > 1:
+        model = model.module
+    if config['env']['verbose']:
+        config['env']['logger'].info(f'Model\'s dtype: {next(model.parameters()).dtype}.')
+    config['spherevit']['dtype'] = next(model.parameters()).dtype
+    return model
da2/utils/vis.py
ADDED
@@ -0,0 +1,44 @@
+import torch
+from PIL import Image
+import numpy as np
+import matplotlib
+import cv2
+
+
+def concatenate_images(*image_lists):
+    max_width = 0
+    total_height = 0
+    row_widths = []
+    row_heights = []
+
+    for i, image_list in enumerate(image_lists):
+        width = sum(img.width for img in image_list)
+        max_width = max(max_width, width)
+        row_widths.append(width)
+        # Assuming all images in the list have the same height
+        height = image_list[0].height
+        total_height += height
+        row_heights.append(height)
+
+    new_image = Image.new('RGB', (max_width, total_height))
+    y_offset = 0
+    for i, image_list in enumerate(image_lists):
+        x_offset = 0
+        for img in image_list:
+            new_image.paste(img, (x_offset, y_offset))
+            x_offset += img.width
+        y_offset += row_heights[i]
+    return new_image
+
+def colorize_distance(distance, mask, cmap='Spectral'):
+    if distance.ndim >= 3: distance = distance.squeeze()
+    cm = matplotlib.colormaps[cmap]
+    valid_distance = distance[mask]
+    max_distance = np.quantile(valid_distance, 0.98)
+    min_distance = np.quantile(valid_distance, 0.02)
+    distance[~mask] = max_distance
+    distance = ((distance - min_distance) / (max_distance - min_distance))
+    distance = np.clip(distance, 0, 1)
+    img_colored_np = cm(distance, bytes=False)[:, :, 0:3]
+    distance_colored = (img_colored_np * 255).astype(np.uint8)
+    return Image.fromarray(distance_colored)
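
A minimal usage sketch for colorize_distance above on synthetic data (assuming the da2 package is importable):

import numpy as np
from da2.utils.vis import colorize_distance

distance = np.random.rand(64, 128).astype(np.float32)  # fake spherical distance map
mask = np.ones_like(distance, dtype=bool)              # keep every pixel
vis = colorize_distance(distance, mask)                # PIL.Image, Spectral colormap
vis.save('distance_vis.png')
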
requirements.txt
CHANGED
@@ -1 +1,19 @@
-
+torch==2.5.0
+torchvision==0.20.0
+torchaudio==2.5.0
+xformers==0.0.28.post2
+diffusers==0.32.0
+tensorboard==2.18.0
+git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900
+opencv-python==4.12.0.88
+gradio==5.49.0
+gradio-client==1.13.3
+gradio-imageslider==0.0.20
+accelerate==1.1.1
+omegaconf==2.3.0
+tabulate==0.9.0
+einops==0.8.0
+timm==1.0.15
+trimesh==4.5.2
+transformers==4.46.3
+matplotlib==3.9.2