PuTorch committed on
Commit 6ffba01 · verified · 1 Parent(s): 618c461

upload CondRef-AR model

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/aerial_img.gif filter=lfs diff=lfs merge=lfs -text
+ assets/control_img.gif filter=lfs diff=lfs merge=lfs -text
+ assets/evolution.png filter=lfs diff=lfs merge=lfs -text
+ assets/method.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/samples.png filter=lfs diff=lfs merge=lfs -text
CondRefAR/models/dinov2_adapter.py ADDED
@@ -0,0 +1,36 @@
+ from transformers import AutoImageProcessor, AutoModel
+ from PIL import Image
+ import requests
+ import torch
+ import torch.nn as nn
+
+
+ class Dinov2_Adapter(nn.Module):
+     def __init__(self, input_dim=1, output_dim=768, attention=False, pool=False, nheads=8, dropout=0.1, adapter_size='small', condition_type='canny'):
+         super(Dinov2_Adapter, self).__init__()
+         print(f"Choose adapter size: {adapter_size}")
+         print(f"condition type: {condition_type}")
+         # NOTE: hardcoded local checkpoint path; point this at your own DINOv2 checkpoint before use.
+         self.model = AutoModel.from_pretrained('D:\\Alps\\Aerial\\Code\\ControlRAR\\checkpoints\\dinov2\\')
+         self.condition_type = condition_type
+
+     def to_patch14(self, input):
+         # Resize from the 16-px tokenizer grid to DINOv2's 14-px patch grid.
+         H, W = input.shape[2:]
+         new_H = (H // 16) * 14
+         new_W = (W // 16) * 14
+         if self.condition_type in ['canny', 'seg']:
+             output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='nearest')  # canny, seg
+         else:
+             output = torch.nn.functional.interpolate(input, size=(new_H, new_W), mode='bicubic', align_corners=True)  # depth, lineart, hed
+         return output
+
+     def forward(self, x):
+         x = self.to_patch14(x)
+         x = self.model(x)
+         return x.last_hidden_state[:, 1:]  # drop the [CLS] token
+
+
+ if __name__ == '__main__':
+     model = Dinov2_Adapter().cuda()
+     inputs = torch.randn(4, 3, 512, 512).cuda()
+     outputs = model(inputs)
+     print(outputs.shape)
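
For reference, a minimal standalone sketch of the feature extraction this adapter performs, loading DINOv2 from the Hugging Face hub instead of the hardcoded local path. The hub id facebook/dinov2-small is an assumption that matches the 384-dim adapter_size='small' MLP in gpt_t2i.py; it is not something this commit ships.

import torch
from transformers import AutoModel

# Assumption: facebook/dinov2-small stands in for the local checkpoint above.
backbone = AutoModel.from_pretrained('facebook/dinov2-small').eval()
x = torch.randn(1, 3, 448, 448)  # 448 = 32 * 14, already a multiple of the 14-px patch
with torch.no_grad():
    feats = backbone(x).last_hidden_state[:, 1:]  # drop [CLS], as in forward() above
print(feats.shape)  # torch.Size([1, 1024, 384])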
CondRefAR/models/generate.py ADDED
@@ -0,0 +1,205 @@
+ # Modified from:
+ #   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
+ #   DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ import torch._dynamo.config
+ import torch._inductor.config
+ import copy
+ import time
+ import pdb
+ # torch._inductor.config.coordinate_descent_tuning = True
+ # torch._inductor.config.triton.unique_kernel_names = True
+ # torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+
+ ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
+ def top_k_top_p_filtering(
+     logits,
+     top_k: int = 0,
+     top_p: float = 1.0,
+     filter_value: float = -float("Inf"),
+     min_tokens_to_keep: int = 1,
+ ):
+     """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+     Args:
+         logits: logits distribution shape (batch size, vocabulary size)
+         if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+         if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+             Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+     Make sure we keep at least min_tokens_to_keep per batch example in the output
+     From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+     """
+     if top_k > 0:
+         top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
+         # Remove all tokens with a probability less than the last token of the top-k
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = filter_value
+
+     if top_p < 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+         sorted_indices_to_remove = cumulative_probs > top_p
+         if min_tokens_to_keep > 1:
+             # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+             sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+         # Shift the indices to the right to keep also the first token above the threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         # scatter sorted tensors to original indexing
+         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+         logits[indices_to_remove] = filter_value
+     return logits
+
+
+ def sample(logits, temperature: float = 1.0, top_k: int = 2000, top_p: float = 1.0, sample_logits=True):
+     logits = logits[:, -1, :] / max(temperature, 1e-5)
+     if top_k > 0 or top_p < 1.0:
+         logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+     probs = F.softmax(logits, dim=-1)
+     if sample_logits:
+         idx = torch.multinomial(probs, num_samples=1)
+     else:
+         _, idx = torch.topk(probs, k=1, dim=-1)
+     return idx, probs
+
+
+ def logits_to_probs(logits, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, **kwargs):
+     # top_k previously defaulted to None, which broke the `top_k > 0` check below; 0 disables top-k.
+     logits = logits / max(temperature, 1e-5)
+     if top_k > 0 or top_p < 1.0:
+         logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+     probs = torch.nn.functional.softmax(logits, dim=-1)
+     return probs
+
+
+ def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition: torch.Tensor, control_strength: float = 1, **sampling_kwargs):
+     if cfg_scale > 1.0:
+         logits, _ = model(None, cond_idx, input_pos, condition=condition, control_strength=control_strength)
+         cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
+         logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+     else:
+         logits, _ = model(None, cond_idx, input_pos, condition=condition)
+
+     return sample(logits, **sampling_kwargs)[0]
+
+
+ def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, condition: torch.Tensor, **sampling_kwargs):
+     assert input_pos.shape[-1] == 1
+     if cfg_scale > 1.0:
+         x_combined = torch.cat([x, x])
+         logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos, condition=condition)
+         cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
+         if cfg_flag:
+             logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
+         else:
+             logits = cond_logits
+     else:
+         logits, _ = model(x, cond_idx=None, input_pos=input_pos, condition=None)
+     return sample(logits, **sampling_kwargs)
+
+
+ def decode_n_tokens(
+     model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
+     cfg_scale: float, cfg_interval: int, condition: torch.Tensor,
+     **sampling_kwargs):
+     new_tokens, new_probs = [], []
+     cfg_flag = True
+     for i in range(num_new_tokens):
+         with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):  # Actually better for Inductor to codegen attention here
+             if cfg_interval > -1 and i > cfg_interval:
+                 cfg_flag = False
+             next_token, next_prob = decode_one_token(
+                 model, cur_token, input_pos, cfg_scale, cfg_flag, condition=condition, **sampling_kwargs
+             )
+             input_pos += 1
+             new_tokens.append(next_token.clone())
+             new_probs.append(next_prob.clone())
+             cur_token = next_token.view(-1, 1)
+
+     return new_tokens, new_probs
+
+
+ @torch.no_grad()
+ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, control_strength=1, **sampling_kwargs):
+     if condition is not None:
+         condition = model.adapter(condition)
+         condition = model.adapter_mlp(condition)
+     if model.model_type == 'c2i':
+         if cfg_scale > 1.0:
+             cond_null = torch.ones_like(cond) * model.num_classes
+             cond_combined = torch.cat([cond, cond_null])
+             if condition is not None:
+                 condition_null = torch.zeros_like(condition)
+                 condition_combined = torch.cat((condition, condition_null), dim=0)
+             else:
+                 condition_combined = None
+         else:
+             cond_combined = cond
+             if condition is not None:
+                 condition_combined = condition
+             else:
+                 condition_combined = None
+         T = 1 + condition_token_nums
+     elif model.model_type == 't2i':
+         if cfg_scale > 1.0:
+             cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
+             cond_combined = torch.cat([cond, cond_null])
+             if condition is not None:
+                 condition_null = torch.zeros_like(condition)
+                 condition_combined = torch.cat((condition, condition_null), dim=0)
+             else:
+                 condition_combined = None
+         else:
+             cond_combined = cond
+             if condition is not None:
+                 condition_combined = condition
+             else:
+                 condition_combined = None
+         T = cond.shape[1]
+     else:
+         raise Exception("please check model type")
+
+     T_new = T + max_new_tokens
+     max_seq_length = T_new
+     max_batch_size = cond.shape[0]
+
+     device = cond.device
+     with torch.device(device):
+         max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
+         model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
+
+     if emb_masks is not None:
+         assert emb_masks.shape[0] == max_batch_size
+         assert emb_masks.shape[-1] == T
+         if cfg_scale > 1.0:
+             model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
+         else:
+             model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
+
+         eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
+         model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
+
+     # create an empty tensor of the expected final shape and fill in the current tokens
+     seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
+     input_pos = torch.arange(0, T, device=device)
+     next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength, **sampling_kwargs)
+     seq[:, T:T+1] = next_token
+
+     input_pos = torch.tensor([T], device=device, dtype=torch.int)
+     generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, condition=condition_combined, **sampling_kwargs)
+     seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
+     return seq[:, T:]
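
A tiny self-contained check of the classifier-free-guidance arithmetic used in prefill and decode_one_token above (toy tensors only, no model involved):

import torch

# The batch is stacked as [conditional rows; unconditional rows].
logits = torch.tensor([[2.0, 0.0, 1.0],   # conditional
                       [1.0, 0.0, 1.0]])  # unconditional
cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
cfg_scale = 4.0
guided = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
print(guided)  # tensor([[5., 0., 1.]]); the conditional offset is amplified 4x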
CondRefAR/models/gpt_t2i.py ADDED
@@ -0,0 +1,588 @@
+ # Modified from:
+ #   VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
+ #   DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
+ #   nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
+ #   llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
+ #   gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
+ #   PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
+ from dataclasses import dataclass
+ from typing import Optional, List
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from ..utils.drop_path import DropPath
+ from .dinov2_adapter import Dinov2_Adapter
+
+
+ def get_causal_mask(seq_length):
+     # Additive attention mask: -inf above the diagonal, 0 elsewhere.
+     # (The original filled the bool mask in place, which cannot hold -inf.)
+     causal = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).type(torch.bool)
+     mask = torch.zeros(seq_length, seq_length).masked_fill(causal, float('-inf'))
+     return mask
+
+
+ def find_multiple(n: int, k: int):
+     if n % k == 0:
+         return n
+     return n + k - (n % k)
+
+
+ @dataclass
+ class ModelArgs:
+     dim: int = 4096
+     n_layer: int = 32
+     n_head: int = 32
+     n_kv_head: Optional[int] = None
+     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+     ffn_dim_multiplier: Optional[float] = None
+     rope_base: float = 10000
+     norm_eps: float = 1e-5
+     initializer_range: float = 0.02
+
+     token_dropout_p: float = 0.1
+     attn_dropout_p: float = 0.0
+     resid_dropout_p: float = 0.1
+     ffn_dropout_p: float = 0.1
+     drop_path_rate: float = 0.0
+
+     num_classes: int = 1000
+     caption_dim: int = 2048
+     class_dropout_prob: float = 0.1
+     model_type: str = 'c2i'
+
+     vocab_size: int = 16384
+     cls_token_num: int = 1
+     block_size: int = 256
+     max_batch_size: int = 32
+     max_seq_len: int = 2048
+     adapter_size: str = 'small'
+     condition_type: str = 'canny'
+
+
+ #################################################################################
+ #                      Embedding Layers for Class Labels                        #
+ #################################################################################
+ class LabelEmbedder(nn.Module):
+     """
+     Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, num_classes, hidden_size, dropout_prob):
+         super().__init__()
+         use_cfg_embedding = dropout_prob > 0
+         self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
+         self.num_classes = num_classes
+         self.dropout_prob = dropout_prob
+
+     def token_drop(self, labels, force_drop_ids=None):
+         """
+         Drops labels to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         labels = torch.where(drop_ids, self.num_classes, labels)
+         return labels, drop_ids
+
+     def forward(self, labels, train, force_drop_ids=None):
+         use_dropout = self.dropout_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             labels, drop_ids = self.token_drop(labels, force_drop_ids)
+         embeddings = self.embedding_table(labels).unsqueeze(1)
+         if (train and use_dropout) or (force_drop_ids is not None):
+             return embeddings, drop_ids
+         else:
+             return embeddings
+
+
+ class ConditionEmbedder(nn.Module):
+     """
+     Embeds Condition into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120, vocab_size=16384):
+         super().__init__()
+         self.cap_proj = MLP(in_features=hidden_size, hidden_features=hidden_size, out_features=hidden_size)
+         self.register_buffer("uncond_embedding", torch.zeros(token_num, hidden_size) / hidden_size ** 0.5)
+         self.uncond_prob = uncond_prob
+
+     def token_drop(self, caption, force_drop_ids=None, drop_ids=None):
+         """
+         Drops labels to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             if drop_ids is None:
+                 drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
+         else:
+             drop_ids = force_drop_ids == 1
+
+         caption = torch.where(drop_ids[:, None, None], self.uncond_embedding[:caption.shape[1]], caption)
+         return caption
+
+     def forward(self, caption, train, force_drop_ids=None, drop_ids=None):
+         use_dropout = self.uncond_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             caption = self.token_drop(caption, force_drop_ids, drop_ids)
+         embeddings = self.cap_proj(caption)
+         return embeddings
+
+
+ #################################################################################
+ #                      Embedding Layers for Text Feature                        #
+ #################################################################################
+ class CaptionEmbedder(nn.Module):
+     """
+     Embeds text caption into vector representations. Also handles label dropout for classifier-free guidance.
+     """
+     def __init__(self, in_channels, hidden_size, uncond_prob, token_num=120):
+         super().__init__()
+         self.cap_proj = MLP(in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size)
+         self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
+         self.uncond_prob = uncond_prob
+
+     def token_drop(self, caption, force_drop_ids=None):
+         """
+         Drops labels to enable classifier-free guidance.
+         """
+         if force_drop_ids is None:
+             drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
+         else:
+             drop_ids = force_drop_ids == 1
+         caption = torch.where(drop_ids[:, None, None], self.uncond_embedding, caption)
+         return caption, drop_ids
+
+     def forward(self, caption, train, force_drop_ids=None):
+         use_dropout = self.uncond_prob > 0
+         if (train and use_dropout) or (force_drop_ids is not None):
+             caption, drop_ids = self.token_drop(caption, force_drop_ids)
+         embeddings = self.cap_proj(caption)
+         if (train and use_dropout) or (force_drop_ids is not None):
+             return embeddings, drop_ids
+         else:
+             return embeddings
+
+
+ class MLP(nn.Module):
+     def __init__(self, in_features, hidden_features, out_features):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
+         self.act = nn.GELU(approximate='tanh')
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
+
+         nn.init.zeros_(self.fc1.weight)
+         nn.init.zeros_(self.fc2.weight)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+         return x
+
+
+ #################################################################################
+ #                                 GPT Model                                     #
+ #################################################################################
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def _norm(self, x):
+         return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+     def forward(self, x):
+         output = self._norm(x.float()).type_as(x)
+         return output * self.weight
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         hidden_dim = 4 * config.dim
+         hidden_dim = int(2 * hidden_dim / 3)
+         # custom dim factor multiplier
+         if config.ffn_dim_multiplier is not None:
+             hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
+         hidden_dim = find_multiple(hidden_dim, config.multiple_of)
+
+         self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
+         self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
+         self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
+
+     def forward(self, x):
+         return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
+
+
+ class KVCache(nn.Module):
+     def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
+         super().__init__()
+         cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
+         self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
+         self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
+
+     def update(self, input_pos, k_val, v_val):
+         # input_pos: [S], k_val: [B, H, S, D]
+         assert input_pos.shape[0] == k_val.shape[2]
+         k_out = self.k_cache
+         v_out = self.v_cache
+         k_out[:, :, input_pos] = k_val
+         v_out[:, :, input_pos] = v_val
+
+         return k_out, v_out
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         assert config.dim % config.n_head == 0
+         self.dim = config.dim
+         self.head_dim = config.dim // config.n_head
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
+         total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
+
+         # key, query, value projections for all heads, but in a batch
+         self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
+         self.wo = nn.Linear(config.dim, config.dim, bias=False)
+         self.kv_cache = None
+
+         # regularization
+         self.attn_dropout_p = config.attn_dropout_p
+         self.resid_dropout = nn.Dropout(config.resid_dropout_p)
+
+     def forward(
+         self, x: torch.Tensor, freqs_cis: torch.Tensor = None,
+         input_pos: Optional[torch.Tensor] = None,
+         mask: Optional[torch.Tensor] = None
+     ):
+         bsz, seqlen, _ = x.shape
+         kv_size = self.n_kv_head * self.head_dim
+         xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+         xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
+
+         xq = apply_rotary_emb(xq, freqs_cis)
+         xk = apply_rotary_emb(xk, freqs_cis)
+
+         xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
+
+         if self.kv_cache is not None:
+             keys, values = self.kv_cache.update(input_pos, xk, xv)
+         else:
+             keys, values = xk, xv
+         keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+         values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
+
+         output = F.scaled_dot_product_attention(
+             xq, keys, values,
+             attn_mask=mask,
+             is_causal=True if mask is None else False,  # is_causal=False is for KV cache
+             dropout_p=self.attn_dropout_p if self.training else 0)
+
+         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+         output = self.resid_dropout(self.wo(output))
+         return output
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: ModelArgs, drop_path: float):
+         super().__init__()
+         self.attention = Attention(config)
+         self.feed_forward = FeedForward(config)
+         self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+     def forward(
+         self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
+         h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask))
+         out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
+         return out
+
+
+ class Transformer(nn.Module):
+     def __init__(self, config: ModelArgs):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.n_layer = config.n_layer
+         self.block_size = config.block_size
+         self.num_classes = config.num_classes
+         self.model_type = config.model_type
+         self.cls_token_num = config.cls_token_num
+         self.layer_internal = config.n_layer // 3  # control features are injected every n_layer//3 blocks
+         self.adapter = Dinov2_Adapter(adapter_size=config.adapter_size, condition_type=config.condition_type)
+         if config.adapter_size == "small":
+             self.adapter_mlp = MLP(384, config.dim, config.dim)
+         elif config.adapter_size == 'base':
+             self.adapter_mlp = MLP(768, config.dim, config.dim)
+
+         if self.model_type == 'c2i':
+             self.cls_embedding = LabelEmbedder(config.num_classes, config.dim, config.class_dropout_prob)
+         elif self.model_type == 't2i':
+             self.cls_embedding = CaptionEmbedder(config.caption_dim, config.dim, config.class_dropout_prob)
+         else:
+             raise Exception("please check model type")
+         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.tok_dropout = nn.Dropout(config.token_dropout_p)
+
+         self.condition_embeddings = nn.Embedding(config.vocab_size, config.dim)
+         self.condition_mlp = ConditionEmbedder(self.block_size, config.dim, config.class_dropout_prob, self.block_size, config.vocab_size)
+         self.condition_layers = torch.nn.ModuleList()
+         for layer_id in range(3):
+             self.condition_layers.append(MLP(config.dim, config.dim, config.dim))
+
+         # transformer blocks
+         dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.n_layer)]
+         self.layers = torch.nn.ModuleList()
+         for layer_id in range(config.n_layer):
+             self.layers.append(TransformerBlock(config, dpr[layer_id]))
+
+         # output layer
+         self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+         # 2d rotary pos embedding
+         grid_size = int(self.block_size ** 0.5)
+         assert grid_size * grid_size == self.block_size
+         self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
+
+         # KVCache
+         self.max_batch_size = -1
+         self.max_seq_length = -1
+
+         self.initialize_weights()
+         self.condition_token = None
+         self.mask = get_causal_mask(256)
+         self.global_token = None
+
+         self.control_strength = 1
+
+     def initialize_weights(self):
+         # Initialize nn.Linear and nn.Embedding
+         self.apply(self._init_weights)
+
+         # Zero-out output layers:
+         nn.init.constant_(self.output.weight, 0)
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+
+     def setup_caches(self, max_batch_size, max_seq_length, dtype):
+         head_dim = self.config.dim // self.config.n_head
+         max_seq_length = find_multiple(max_seq_length, 8)
+         self.max_seq_length = max_seq_length
+         self.max_batch_size = max_batch_size
+         for b in self.layers:
+             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim, dtype)
+
+         causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+         self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
+         grid_size = int(self.config.block_size ** 0.5)
+         assert grid_size * grid_size == self.block_size
+         self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, self.config.rope_base, self.cls_token_num)
+
+     def forward(
+         self,
+         idx: torch.Tensor,
+         cond_idx: torch.Tensor,  # cond_idx_or_embed
+         input_pos: Optional[torch.Tensor] = None,
+         targets: Optional[torch.Tensor] = None,
+         mask: Optional[torch.Tensor] = None,
+         valid: Optional[torch.Tensor] = None,
+         condition: Optional[torch.Tensor] = None,
+         control_strength: Optional[int] = 1
+     ):
+         if idx is not None and cond_idx is not None:  # training or naive inference
+             cond_embeddings, drop_ids = self.cls_embedding(cond_idx, train=self.training)
+             cond_embeddings = cond_embeddings[:, :self.cls_token_num]
+             token_embeddings = self.tok_embeddings(idx)
+             if condition is not None:
+                 condition_embeddings = self.adapter(condition)
+                 condition_embeddings = self.adapter_mlp(condition_embeddings)
+                 self.condition_token = self.condition_mlp(condition_embeddings, train=self.training, drop_ids=drop_ids)
+             token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
+
+             h = self.tok_dropout(token_embeddings)
+             self.freqs_cis = self.freqs_cis.to(h.device)
+         else:
+             if cond_idx is not None:  # prefill in inference
+                 self.control_strength = control_strength
+                 token_embeddings = self.cls_embedding(cond_idx, train=self.training)
+                 token_embeddings = token_embeddings[:, :self.cls_token_num]
+                 if condition is not None:
+                     condition_embeddings = self.condition_mlp(condition, train=self.training)
+                     self.condition_token = [self.condition_layers[0](condition_embeddings),
+                                             self.condition_layers[1](condition_embeddings),
+                                             self.condition_layers[2](condition_embeddings)]
+             else:  # decode_n_tokens(kv cache) in inference
+                 token_embeddings = self.tok_embeddings(idx)
+             bs = token_embeddings.shape[0]
+             mask = self.causal_mask[:bs, None, input_pos]
+             h = self.tok_dropout(token_embeddings)
+
+         if self.training:
+             freqs_cis = self.freqs_cis[:token_embeddings.shape[1]]
+         else:
+             freqs_cis = self.freqs_cis[input_pos]
+
+         # transformer blocks
+         for i, layer in enumerate(self.layers):
+             if i % self.layer_internal == 0:
+                 if self.training:
+                     h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token)
+                 else:
+                     if len(input_pos) > 1:  # prefill: add control to the last prefix position
+                         h[:, -1:] = h[:, -1:] + self.control_strength * self.condition_token[i//self.layer_internal][:, 0:1]
+                     else:  # decoding: add the control token aligned with the current position
+                         h = h + self.control_strength * self.condition_token[i//self.layer_internal][:, input_pos-self.cls_token_num+1]
+             h = layer(h, freqs_cis, input_pos, mask)
+
+         # output layers
+         h = self.norm(h)
+         logits = self.output(h).float()
+
+         if self.training:
+             logits = logits[:, self.cls_token_num - 1:].contiguous()
+
+         # if we are given some desired targets also calculate the loss
+         loss = None
+         if valid is not None:
+             loss_all = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
+             valid_all = valid[:, None].repeat(1, targets.shape[1]).view(-1)
+             loss = (loss_all * valid_all).sum() / max(valid_all.sum(), 1)
+         elif targets is not None:
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+         return logits, loss
+
+     def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+         return list(self.layers)
+
+
+ #################################################################################
+ #                      Rotary Positional Embedding Functions                    #
+ #################################################################################
+ # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
+     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+     t = torch.arange(seq_len, device=freqs.device)
+     freqs = torch.outer(t, freqs)  # (seq_len, head_dim // 2)
+     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (seq_len, head_dim // 2, 2)
+     cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+seq_len, head_dim // 2, 2)
+     return cond_cache
+
+
+ def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
+     # split the dimension into half, one for x and one for y
+     half_dim = n_elem // 2
+     freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
+     t = torch.arange(grid_size, device=freqs.device)
+     freqs = torch.outer(t, freqs)  # (grid_size, head_dim // 2)
+     freqs_grid = torch.concat([
+         freqs[:, None, :].expand(-1, grid_size, -1),
+         freqs[None, :, :].expand(grid_size, -1, -1),
+     ], dim=-1)  # (grid_size, grid_size, head_dim // 2)
+     cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1)  # (grid_size, grid_size, head_dim // 2, 2)
+     cache = cache_grid.flatten(0, 1)
+     cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache])  # (cls_token_num+grid_size**2, head_dim // 2, 2)
+     return cond_cache
+
+
+ def precompute_freqs_cis_2d_new(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120, spe_token_num=3, ar_token_num=4):
+     # split the dimension into half, one for x and one for y
+     half_dim = n_elem // 2
+     freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
+     t = torch.arange(grid_size, device=freqs.device)
+     freqs = torch.outer(t, freqs)  # (grid_size, head_dim // 2)
+     freqs_grid = torch.concat([
+         freqs[:, None, :].expand(-1, grid_size, -1),
+         freqs[None, :, :].expand(grid_size, -1, -1),
+     ], dim=-1)  # (grid_size, grid_size, head_dim // 2)
+     cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], dim=-1)  # (grid_size, grid_size, head_dim // 2, 2)
+     sub_num = int(ar_token_num ** 0.5)
+
+     cache_grid = cache_grid.reshape(sub_num, grid_size//sub_num, sub_num, grid_size//sub_num, half_dim, 2)
+     cache_grid = cache_grid.permute(1, 3, 0, 2, 4, 5)
+     cache = cache_grid.flatten(0, 3)
+     cache_one, cache_two = cache[:ar_token_num], cache[ar_token_num:]
+     sep_cache = torch.zeros(spe_token_num, n_elem // 2, 2)
+     cond_cache = torch.cat([torch.zeros(cls_token_num, n_elem // 2, 2), cache_one, sep_cache, cache_two])
+     return cond_cache
+
+
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
+     # x: (bs, seq_len, n_head, head_dim)
+     # freqs_cis: (seq_len, head_dim // 2, 2)
+     xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # (bs, seq_len, n_head, head_dim//2, 2)
+     freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)  # (1, seq_len, 1, head_dim//2, 2)
+     x_out2 = torch.stack([
+         xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+         xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+     ], dim=-1)
+     x_out2 = x_out2.flatten(3)
+     return x_out2.type_as(x)
+
+
+ #################################################################################
+ #                                GPT Configs                                    #
+ #################################################################################
+ ### text-conditional
+ def GPT_7B(**kwargs):
+     return Transformer(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs))  # 6.6B
+
+ def GPT_3B(**kwargs):
+     return Transformer(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs))  # 3.1B
+
+ def GPT_1B(**kwargs):
+     return Transformer(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs))  # 1.2B
+
+ ### class-conditional
+ def GPT_XXXL(**kwargs):
+     return Transformer(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs))  # 3.9B
+
+ def GPT_XXL(**kwargs):
+     return Transformer(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs))  # 1.4B
+
+ def GPT_XL(**kwargs):
+     return Transformer(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs))  # 775M
+
+ def GPT_L(**kwargs):
+     return Transformer(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs))  # 343M
+
+ def GPT_B(**kwargs):
+     return Transformer(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs))  # 111M
+
+
+ GPT_models = {
+     'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
+     'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
+ }
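
As a shape sanity check for the rotary helpers above, a minimal sketch; GPT-B geometry is assumed (dim=768, n_head=12, block_size=256, i.e. a 16x16 latent grid):

import torch

head_dim, grid_size, cls_token_num = 64, 16, 120   # 768 / 12 heads = 64
freqs_cis = precompute_freqs_cis_2d(grid_size, head_dim, base=10000, cls_token_num=cls_token_num)
print(freqs_cis.shape)  # torch.Size([376, 32, 2]) = (120 + 256, head_dim//2, 2)

x = torch.randn(2, grid_size * grid_size, 12, head_dim)  # (bs, seq_len, n_head, head_dim)
x_rot = apply_rotary_emb(x, freqs_cis[cls_token_num:])   # skip the zeroed prefix positions
print(x_rot.shape)  # torch.Size([2, 256, 12, 64])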
CondRefAR/pipeline.py ADDED
@@ -0,0 +1,112 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from safetensors.torch import load_file
+ from .models.gpt_t2i import GPT_models
+ from .models.generate import generate
+ from .tokenizer.vq_model import VQ_models
+
+
+ class CondRefARPipeline:
+     def __init__(self, device=None, torch_dtype=torch.bfloat16):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.dtype = torch_dtype
+         self.gpt = None
+         self.vq = None
+         self.image_size = None
+         self.downsample = None
+         self.n_q = 8
+
+     @classmethod
+     def from_pretrained(cls, repo_or_path, gpt_config, vq_config, gpt_weights="weights/sketch-gpt-xl.safetensors", vq_weights="weights/vq-16.safetensors", device=None, torch_dtype=torch.bfloat16):
+         pipe = cls(device=device, torch_dtype=torch_dtype)
+
+         # 1) VQ tokenizer
+         pipe.downsample = int(vq_config["downsample_size"])
+         codebook_size = int(vq_config["codebook_size"])
+         codebook_embed_dim = int(vq_config["codebook_embed_dim"])
+         pipe.vq = VQ_models[vq_config.get("model_name", "VQ-16")](codebook_size=codebook_size, codebook_embed_dim=codebook_embed_dim)
+         vq_state = load_file(f"{repo_or_path}/{vq_weights}")
+         pipe.vq.load_state_dict(vq_state, strict=True)
+         pipe.vq.to(pipe.device)
+         pipe.vq.eval()
+
+         # 2) GPT
+         pipe.image_size = int(gpt_config["image_size"])
+         vocab_size = int(gpt_config["vocab_size"])
+         latent_size = pipe.image_size // pipe.downsample
+         block_size = latent_size ** 2
+         num_classes = int(gpt_config.get("num_classes", 1000))
+         cls_token_num = int(gpt_config.get("cls_token_num", 120))
+         model_type = gpt_config.get("model_type", "t2i")
+         adapter_size = gpt_config.get("adapter_size", "small")
+         condition_type = gpt_config.get("condition_type", "sketch")
+
+         pipe.gpt = GPT_models[gpt_config.get("gpt_name", "GPT-XL")](
+             vocab_size=vocab_size,
+             block_size=block_size,
+             num_classes=num_classes,
+             cls_token_num=cls_token_num,
+             model_type=model_type,
+             adapter_size=adapter_size,
+             condition_type=condition_type
+         ).to(device=pipe.device, dtype=pipe.dtype)
+         gpt_state = load_file(f"{repo_or_path}/{gpt_weights}")
+         pipe.gpt.load_state_dict(gpt_state, strict=False)
+         pipe.gpt.eval()
+
+         return pipe
+
+     @torch.inference_mode()
+     def __call__(self, prompt_emb, control_image, cfg_scale=4, cfg_interval=-1, temperature=1.0, top_k=2000, top_p=1.0):
+         """
+         prompt_emb: torch.Tensor [B, T_txt, D]
+         control_image: np.ndarray / PIL.Image
+         Returns: list of PIL Images
+         """
+         # Preprocess the control image
+         if isinstance(control_image, Image.Image):
+             control_image = np.array(control_image.convert("RGB"))
+         if isinstance(control_image, np.ndarray):
+             # [H,W,C] uint8 -> [-1,1]
+             control_image = torch.from_numpy(control_image).permute(2, 0, 1).unsqueeze(0).float()
+             if control_image.max() > 1.0:
+                 control_image = control_image / 255.0
+             control_image = 2.0 * (control_image - 0.5)
+         control = control_image.to(self.device, dtype=self.dtype)
+         # Text embeddings
+         c_indices = prompt_emb.to(self.device, dtype=self.dtype)
+         # If an embedding mask is needed, build it externally and pass it in;
+         # it is left as None here to keep this minimal.
+         c_emb_masks = None
+
+         Hq = self.image_size // self.downsample
+         Wq = Hq
+         seq_len = Hq * Wq
+         # Sample the codebook index sequence (generate returns [B, n_q*Hq*Wq] or [B, seq_len], one codebook at a time)
+         index_sample = generate(
+             self.gpt, c_indices, seq_len, c_emb_masks,
+             condition=control, cfg_scale=cfg_scale, cfg_interval=cfg_interval,
+             temperature=temperature, top_k=top_k, top_p=top_p, sample_logits=True
+         )
+         # Rearrange to [B, n_q, Hq, Wq]
+         if index_sample.dim() == 2 and index_sample.shape[1] == self.n_q * Hq * Wq:
+             tokens = index_sample.view(index_sample.size(0), self.n_q, Hq, Wq).long()
+         elif index_sample.dim() == 2 and index_sample.shape[1] == Hq * Wq:
+             tokens = index_sample.view(index_sample.size(0), 1, Hq, Wq).long()
+         else:
+             # Try to infer n_q automatically
+             n_q = max(1, index_sample.shape[1] // (Hq * Wq))
+             tokens = index_sample[:, : n_q * Hq * Wq].view(index_sample.size(0), n_q, Hq, Wq).long()
+         tokens = tokens.to(self.device)
+         qzshape = [tokens.size(0), 8, Hq, Wq]  # 8 = codebook_embed_dim
+         samples = self.vq.decode_code(tokens, qzshape).detach().float().cpu()
+         # [-1,1] -> [0,1]
+         if samples.min() < -0.9:
+             samples = (samples + 1.0) / 2.0
+         samples = samples.clamp(0, 1)
+
+         imgs = []
+         arr = (samples * 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
+         for i in range(arr.shape[0]):
+             imgs.append(Image.fromarray(arr[i]))
+         return imgs
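
An end-to-end usage sketch for this pipeline. The config values and the prompt embedding below are illustrative assumptions (the random tensor stands in for a real T5-style text embedding of shape [B, cls_token_num, caption_dim]); consult the repo's own configs and weights for the actual values.

import torch
from PIL import Image
from CondRefAR.pipeline import CondRefARPipeline

# Hypothetical configs; keys mirror those read in from_pretrained above.
gpt_config = {"image_size": 512, "vocab_size": 16384, "gpt_name": "GPT-XL",
              "cls_token_num": 120, "model_type": "t2i",
              "adapter_size": "small", "condition_type": "sketch"}
vq_config = {"downsample_size": 16, "codebook_size": 16384,
             "codebook_embed_dim": 8, "model_name": "VQ-16"}

pipe = CondRefARPipeline.from_pretrained(".", gpt_config, vq_config)
prompt_emb = torch.randn(1, 120, 2048)  # stand-in for a real text encoder output
control = Image.open("assets/control_img.gif").convert("RGB").resize((512, 512))
imgs = pipe(prompt_emb, control, cfg_scale=4)
imgs[0].save("sample.png")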
CondRefAR/tokenizer/vq_model.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from:
2
+ # taming-transformers: https://github.com/CompVis/taming-transformers
3
+ # maskgit: https://github.com/google-research/maskgit
4
+ from dataclasses import dataclass, field
5
+ from typing import List
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ @dataclass
13
+ class ModelArgs:
14
+ codebook_size: int = 16384
15
+ codebook_embed_dim: int = 8
16
+ codebook_l2_norm: bool = True
17
+ codebook_show_usage: bool = True
18
+ commit_loss_beta: float = 0.25
19
+ entropy_loss_ratio: float = 0.0
20
+
21
+ encoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
22
+ decoder_ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
23
+ z_channels: int = 256
24
+ dropout_p: float = 0.0
25
+
26
+
27
+
28
+ class VQModel(nn.Module):
29
+ def __init__(self, config: ModelArgs):
30
+ super().__init__()
31
+ self.config = config
32
+ self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
33
+ self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
34
+
35
+ self.quantize = VectorQuantizer(config.codebook_size, config.codebook_embed_dim,
36
+ config.commit_loss_beta, config.entropy_loss_ratio,
37
+ config.codebook_l2_norm, config.codebook_show_usage)
38
+ self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
39
+ self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
40
+
41
+ def encode(self, x):
42
+ #import pdb; pdb.set_trace()
43
+ h = self.encoder(x)
44
+ h = self.quant_conv(h)
45
+ quant, emb_loss, info = self.quantize(h)
46
+ return quant, emb_loss, info
47
+
48
+ def decode(self, quant):
49
+ quant = self.post_quant_conv(quant)
50
+ dec = self.decoder(quant)
51
+ return dec
52
+
53
+ def decode_code(self, code_b, shape=None, channel_first=True):
54
+ quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
55
+ dec = self.decode(quant_b)
56
+ return dec
57
+
58
+ def forward(self, input):
59
+ quant, diff, _ = self.encode(input)
60
+ dec = self.decode(quant)
61
+ return dec, diff
62
+
63
+
64
+
65
+ class Encoder(nn.Module):
66
+ def __init__(self, in_channels=3, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2,
67
+ norm_type='group', dropout=0.0, resamp_with_conv=True, z_channels=256):
68
+ super().__init__()
69
+ self.num_resolutions = len(ch_mult)
70
+ self.num_res_blocks = num_res_blocks
71
+ self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1)
72
+
73
+ # downsampling
74
+ in_ch_mult = (1,) + tuple(ch_mult)
75
+ self.conv_blocks = nn.ModuleList()
76
+ for i_level in range(self.num_resolutions):
77
+ conv_block = nn.Module()
78
+ # res & attn
79
+ res_block = nn.ModuleList()
80
+ attn_block = nn.ModuleList()
81
+ block_in = ch*in_ch_mult[i_level]
82
+ block_out = ch*ch_mult[i_level]
83
+ for _ in range(self.num_res_blocks):
84
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
85
+ block_in = block_out
86
+ if i_level == self.num_resolutions - 1:
87
+ attn_block.append(AttnBlock(block_in, norm_type))
88
+ conv_block.res = res_block
89
+ conv_block.attn = attn_block
90
+ # downsample
91
+ if i_level != self.num_resolutions-1:
92
+ conv_block.downsample = Downsample(block_in, resamp_with_conv)
93
+ self.conv_blocks.append(conv_block)
94
+
95
+ # middle
96
+ self.mid = nn.ModuleList()
97
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
98
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
99
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
100
+
101
+ # end
102
+ self.norm_out = Normalize(block_in, norm_type)
103
+ self.conv_out = nn.Conv2d(block_in, z_channels, kernel_size=3, stride=1, padding=1)
104
+
105
+
106
+ def forward(self, x):
107
+ h = self.conv_in(x)
108
+ # downsampling
109
+ for i_level, block in enumerate(self.conv_blocks):
110
+ for i_block in range(self.num_res_blocks):
111
+ h = block.res[i_block](h)
112
+ if len(block.attn) > 0:
113
+ h = block.attn[i_block](h)
114
+ if i_level != self.num_resolutions - 1:
115
+ h = block.downsample(h)
116
+
117
+ # middle
118
+ for mid_block in self.mid:
119
+ h = mid_block(h)
120
+
121
+ # end
122
+ h = self.norm_out(h)
123
+ h = nonlinearity(h)
124
+ h = self.conv_out(h)
125
+ return h
126
+
127
+
128
+
129
+ class Decoder(nn.Module):
130
+ def __init__(self, z_channels=256, ch=128, ch_mult=(1,1,2,2,4), num_res_blocks=2, norm_type="group",
131
+ dropout=0.0, resamp_with_conv=True, out_channels=3):
132
+ super().__init__()
133
+ self.num_resolutions = len(ch_mult)
134
+ self.num_res_blocks = num_res_blocks
135
+
136
+ block_in = ch*ch_mult[self.num_resolutions-1]
137
+ # z to block_in
138
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
139
+
140
+ # middle
141
+ self.mid = nn.ModuleList()
142
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
143
+ self.mid.append(AttnBlock(block_in, norm_type=norm_type))
144
+ self.mid.append(ResnetBlock(block_in, block_in, dropout=dropout, norm_type=norm_type))
145
+
146
+ # upsampling
147
+ self.conv_blocks = nn.ModuleList()
148
+ for i_level in reversed(range(self.num_resolutions)):
149
+ conv_block = nn.Module()
150
+ # res & attn
151
+ res_block = nn.ModuleList()
152
+ attn_block = nn.ModuleList()
153
+ block_out = ch*ch_mult[i_level]
154
+ for _ in range(self.num_res_blocks + 1):
155
+ res_block.append(ResnetBlock(block_in, block_out, dropout=dropout, norm_type=norm_type))
156
+ block_in = block_out
157
+ if i_level == self.num_resolutions - 1:
158
+ attn_block.append(AttnBlock(block_in, norm_type))
159
+ conv_block.res = res_block
160
+ conv_block.attn = attn_block
161
+ # downsample
162
+ if i_level != 0:
163
+ conv_block.upsample = Upsample(block_in, resamp_with_conv)
164
+ self.conv_blocks.append(conv_block)
165
+
166
+ # end
167
+ self.norm_out = Normalize(block_in, norm_type)
168
+ self.conv_out = nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
169
+
170
+ @property
171
+ def last_layer(self):
172
+ return self.conv_out.weight
173
+
174
+ def forward(self, z):
175
+ # z to block_in
176
+ h = self.conv_in(z)
177
+
178
+ # middle
179
+ for mid_block in self.mid:
180
+ h = mid_block(h)
181
+
182
+ # upsampling
183
+ for i_level, block in enumerate(self.conv_blocks):
184
+ for i_block in range(self.num_res_blocks + 1):
185
+ h = block.res[i_block](h)
186
+ if len(block.attn) > 0:
187
+ h = block.attn[i_block](h)
188
+ if i_level != self.num_resolutions - 1:
189
+ h = block.upsample(h)
190
+
191
+ # end
192
+ h = self.norm_out(h)
193
+ h = nonlinearity(h)
194
+ h = self.conv_out(h)
195
+ return h
196
+
197
+
198
+ class VectorQuantizer(nn.Module):
199
+ def __init__(self, n_e, e_dim, beta, entropy_loss_ratio, l2_norm, show_usage):
200
+ super().__init__()
201
+ self.n_e = n_e
202
+ self.e_dim = e_dim
203
+ self.beta = beta
204
+ self.entropy_loss_ratio = entropy_loss_ratio
205
+ self.l2_norm = l2_norm
206
+ self.show_usage = show_usage
207
+
208
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
209
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
210
+ if self.l2_norm:
211
+ self.embedding.weight.data = F.normalize(self.embedding.weight.data, p=2, dim=-1)
212
+ if self.show_usage:
213
+ self.register_buffer("codebook_used", nn.Parameter(torch.zeros(65536)))
214
+
215
+
216
+ def forward(self, z):
217
+ # reshape z -> (batch, height, width, channel) and flatten
218
+ z = torch.einsum('b c h w -> b h w c', z).contiguous()
219
+ z_flattened = z.view(-1, self.e_dim)
220
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
221
+
222
+ if self.l2_norm:
223
+ z = F.normalize(z, p=2, dim=-1)
224
+ z_flattened = F.normalize(z_flattened, p=2, dim=-1)
225
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
226
+ else:
227
+ embedding = self.embedding.weight
228
+
229
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
230
+ torch.sum(embedding**2, dim=1) - 2 * \
231
+ torch.einsum('bd,dn->bn', z_flattened, torch.einsum('n d -> d n', embedding))
232
+
233
+ min_encoding_indices = torch.argmin(d, dim=1)
234
+ z_q = embedding[min_encoding_indices].view(z.shape)
235
+ perplexity = None
236
+ min_encodings = None
237
+ vq_loss = None
238
+ commit_loss = None
239
+ entropy_loss = None
240
+         codebook_usage = 0
+
+         if self.show_usage and self.training:
+             # rolling buffer of recently selected code indices; the number of
+             # unique entries estimates what fraction of the codebook is in use
+             cur_len = min_encoding_indices.shape[0]
+             self.codebook_used[:-cur_len] = self.codebook_used[cur_len:].clone()
+             self.codebook_used[-cur_len:] = min_encoding_indices
+             codebook_usage = len(torch.unique(self.codebook_used)) / self.n_e
+
+         # compute loss for embedding
+         if self.training:
+             vq_loss = torch.mean((z_q - z.detach()) ** 2)
+             commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
+             entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
+
+         # preserve gradients (straight-through estimator)
+         z_q = z + (z_q - z).detach()
+
+         # reshape back to match original input shape
+         z_q = torch.einsum('b h w c -> b c h w', z_q)
+
+         return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage), (perplexity, min_encodings, min_encoding_indices)
+
+     def get_codebook_entry(self, indices, shape=None, channel_first=True):
+         # shape = (batch, channel, height, width) if channel_first else (batch, height, width, channel)
+         if self.l2_norm:
+             embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
+         else:
+             embedding = self.embedding.weight
+         z_q = embedding[indices]  # (b*h*w, c)
+
+         if shape is not None:
+             if channel_first:
+                 z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
+                 # reshape back to match original input shape
+                 z_q = z_q.permute(0, 3, 1, 2).contiguous()
+             else:
+                 z_q = z_q.view(shape)
+         return z_q
+
+
+ class ResnetBlock(nn.Module):
+     def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group'):
+         super().__init__()
+         self.in_channels = in_channels
+         out_channels = in_channels if out_channels is None else out_channels
+         self.out_channels = out_channels
+         self.use_conv_shortcut = conv_shortcut
+
+         self.norm1 = Normalize(in_channels, norm_type)
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+         self.norm2 = Normalize(out_channels, norm_type)
+         self.dropout = nn.Dropout(dropout)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+             else:
+                 self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h = x
+         h = self.norm1(h)
+         h = nonlinearity(h)
+         h = self.conv1(h)
+         h = self.norm2(h)
+         h = nonlinearity(h)
+         h = self.dropout(h)
+         h = self.conv2(h)
+
+         if self.in_channels != self.out_channels:
+             if self.use_conv_shortcut:
+                 x = self.conv_shortcut(x)
+             else:
+                 x = self.nin_shortcut(x)
+         return x + h
+
+
+ class AttnBlock(nn.Module):
+     def __init__(self, in_channels, norm_type='group'):
+         super().__init__()
+         self.norm = Normalize(in_channels, norm_type)
+         self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+         self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = q.reshape(b, c, h * w)
+         q = q.permute(0, 2, 1)      # b, hw, c
+         k = k.reshape(b, c, h * w)  # b, c, hw
+         w_ = torch.bmm(q, k)        # b, hw, hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = F.softmax(w_, dim=2)
+
+         # attend to values
+         v = v.reshape(b, c, h * w)
+         w_ = w_.permute(0, 2, 1)    # b, hw, hw (first hw of k, second of q)
+         h_ = torch.bmm(v, w_)       # b, c, hw (hw of q)    h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+         h_ = h_.reshape(b, c, h, w)
+
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
+ def nonlinearity(x):
+     # swish
+     return x * torch.sigmoid(x)
+
+
+ def Normalize(in_channels, norm_type='group'):
+     assert norm_type in ['group', 'batch']
+     if norm_type == 'group':
+         return nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+     elif norm_type == 'batch':
+         return nn.SyncBatchNorm(in_channels)
+
+
+ class Upsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+     def forward(self, x):
+         x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+         if self.with_conv:
+             x = self.conv(x)
+         return x
+
+
+ class Downsample(nn.Module):
+     def __init__(self, in_channels, with_conv):
+         super().__init__()
+         self.with_conv = with_conv
+         if self.with_conv:
+             # no asymmetric padding in torch conv, must do it ourselves
+             self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+     def forward(self, x):
+         if self.with_conv:
+             pad = (0, 1, 0, 1)
+             x = F.pad(x, pad, mode="constant", value=0)
+             x = self.conv(x)
+         else:
+             x = F.avg_pool2d(x, kernel_size=2, stride=2)
+         return x
+
+
+ def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
+     # Encourages confident per-sample code assignments (low sample entropy)
+     # while spreading assignments across the codebook (high average entropy):
+     # loss = mean_i H(p_i) - H(mean_i p_i)
+     flat_affinity = affinity.reshape(-1, affinity.shape[-1])
+     flat_affinity /= temperature
+     probs = F.softmax(flat_affinity, dim=-1)
+     log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
+     if loss_type == "softmax":
+         target_probs = probs
+     else:
+         raise ValueError("Entropy loss {} not supported".format(loss_type))
+     avg_probs = torch.mean(target_probs, dim=0)
+     avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-5))
+     sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
+     loss = sample_entropy - avg_entropy
+     return loss
+
+
+ #################################################################################
+ #                              VQ Model Configs                                 #
+ #################################################################################
+ def VQ_8(**kwargs):
+     return VQModel(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))
+
+ def VQ_16(**kwargs):
+     return VQModel(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))
+
+ VQ_models = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
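
For orientation, a minimal sketch of what this VQ stage implies for the autoregressive model. It assumes the `VQModel` wrapper (not shown in this diff) exposes its quantizer as `model.quantize` and that `ModelArgs` accepts `codebook_size` and `codebook_embed_dim` keywords; treat those names as assumptions. The arithmetic follows directly from `configs/vq_config.json`.

```python
import torch

# With VQ-16, a 512x512 image maps to a 32x32 grid of code indices,
# i.e. 1024 tokens per image for the autoregressive model.
image_size, downsample = 512, 16
grid = image_size // downsample      # 32
print(grid * grid)                   # 1024

# Hypothetical decode path via get_codebook_entry (defined above):
# flat indices -> (B, C, H, W) feature map for the decoder.
model = VQ_models['VQ-16'](codebook_size=16384, codebook_embed_dim=8)
indices = torch.randint(0, 16384, (grid * grid,))
z_q = model.quantize.get_codebook_entry(indices, shape=(1, 8, grid, grid), channel_first=True)
print(z_q.shape)                     # torch.Size([1, 8, 32, 32])
```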
CondRefAR/utils/drop_path.py ADDED
@@ -0,0 +1,36 @@
+ # from timm.models.layers import DropPath
+ import torch
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+     """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).
+
+     This is the same as the DropConnect implementation created for EfficientNet-style networks;
+     however, the original name is misleading, as 'Drop Connect' is a different form of dropout from
+     a separate paper (see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
+     The layer and argument names here use 'drop path' rather than mixing 'DropConnect' as a layer
+     name with 'survival rate' as the argument.
+     """
+     if drop_prob == 0. or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
+
+ class DropPath(torch.nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""
+
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f'drop_prob={round(self.drop_prob, 3):0.3f}'
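
A quick illustrative sanity check of the stochastic-depth behavior above (not part of the repository): in training mode each sample in the batch either survives and is rescaled by `1/keep_prob` or is zeroed out; in eval mode the module is the identity.

```python
import torch
from CondRefAR.utils.drop_path import DropPath

dp = DropPath(drop_prob=0.2)
dp.train()
x = torch.ones(8, 16)
y = dp(x)
# each row is either all zeros (dropped) or all ~1/0.8 = 1.25 (kept, rescaled)
kept = y[y > 0]
assert torch.allclose(kept, torch.full_like(kept, 1.25))

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time
```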
README.md CHANGED
@@ -1,3 +1,78 @@
  ---
+ library_name: pytorch
+ tags:
+ - autoregressive
+ - image-generation
+ - aerial
+ - controllable-generation
  license: apache-2.0
+ pipeline_tag: image-to-image
  ---
+
+ # CondRef-AR: Condition-as-a-Reference Randomized Autoregressive Modelling for Controllable Aerial Image Generation
+
+ This repository contains the code and pretrained models for **CondRef-AR**, a controllable aerial image generation model that uses condition-as-a-reference randomized autoregressive modeling. The model generates high-quality aerial images from input conditions such as sketches or segmentation maps.
+
+ ![CondRef-AR Overview](assets/method.jpg)
+
+ ## Quickstart
+
+ ```python
+ import json, torch
+ from CondRefAR.pipeline import CondRefARPipeline
+ from transformers import AutoTokenizer, T5EncoderModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ gpt_cfg = json.load(open("configs/gpt_config.json"))
+ vq_cfg = json.load(open("configs/vq_config.json"))
+ pipe = CondRefARPipeline.from_pretrained(".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype)
+
+ tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+ enc = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype).to(device).eval()
+
+ prompt = "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building."
+ control = "assets/examples/example2.jpg"
+
+ from PIL import Image
+ control_img = Image.open(control).convert("RGB")
+
+ inputs = tok([prompt], return_tensors="pt", padding="max_length", truncation=True, max_length=120)
+ with torch.no_grad():
+     emb = enc(input_ids=inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device)).last_hidden_state
+
+ imgs = pipe(emb, control_img, cfg_scale=4, temperature=1.0, top_k=2000, top_p=1.0)
+ imgs[0].save("sample.png")
+ ```
+
+ ## Sample Results
+ By varying the input conditions and prompts, CondRef-AR can generate diverse aerial images:
+ ![Samples](assets/samples.png)
+
+ CondRef-AR can also generate continuous, plausible, high-resolution sequences of land-use change images from a series of temporal semantic condition maps; a scripted sketch of this frame-by-frame usage follows the table below. As shown in the figure, the model simulates the entire process, from a pristine forest gradually transforming into a modern residential urban area:
+
+ ![Temporal Generation](assets/evolution.png)
+ <div align="center">
+
+ | Control image | Aerial image |
+ |---|---|
+ | <img src="assets/control_img.gif" alt="control animation" width="100%"/> | <img src="assets/aerial_img.gif" alt="aerial animation" width="100%"/> |
+
+ </div>
+
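+ The sketch below is illustrative only: it reuses the Quickstart objects `pipe` and `emb`, and the per-step frame paths are hypothetical placeholders for your own condition sequence.
+
+ ```python
+ from PIL import Image
+
+ # hypothetical per-step condition maps; substitute your own sequence
+ frames = [f"assets/examples/step_{i}.png" for i in range(4)]
+ for i, path in enumerate(frames):
+     control_img = Image.open(path).convert("RGB")
+     imgs = pipe(emb, control_img, cfg_scale=4, temperature=1.0, top_k=2000, top_p=1.0)
+     imgs[0].save(f"frame_{i}.png")
+ ```
+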
+ ## Files
+ - `weights/sketch-gpt-xl.safetensors`, `weights/vq-16.safetensors`: pretrained weights.
+ - `configs/*.json`: model hyperparameters.
+ - `CondRefAR/*`: inference code and pipeline.
+ - `assets/examples/`: example images.
+ - `app.py`: Gradio demo.
+
+ ## Notes
+ - A GPU with bfloat16 support gives the best speed; CPU works but is slow.
+ - The sampling parameters `cfg_scale`, `temperature`, `top_k`, and `top_p` trade off fidelity to the condition against sample diversity; see the sketch below.
+ - If you have any questions, please open an issue or contact [email protected].
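+
+ A minimal sweep (illustrative; it reuses `pipe`, `emb`, and `control_img` from the Quickstart) to compare guidance strengths:
+
+ ```python
+ # higher cfg_scale follows the condition and prompt more closely; lower adds diversity
+ for s in (2.0, 4.0, 6.0):
+     imgs = pipe(emb, control_img, cfg_scale=s, temperature=1.0, top_k=2000, top_p=1.0)
+     imgs[0].save(f"sample_cfg{s}.png")
+ ```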
+
+ ## License
+ Apache-2.0.
app.py ADDED
@@ -0,0 +1,100 @@
+ import gradio as gr
+ import torch
+ import json
+ from CondRefAR.pipeline import CondRefARPipeline
+ from transformers import AutoTokenizer, T5EncoderModel
+
+ # Simplification: extract text embeddings directly with transformers' flan-t5-xl
+ def build_t5(device, dtype):
+     tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+     enc = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype)
+     enc = enc.to(device)
+     enc.eval()
+     return tok, enc
+
+ def text_to_emb(prompt, tok, enc, device, dtype):
+     inputs = tok([prompt], return_tensors="pt", padding='max_length', truncation=True, return_attention_mask=True, add_special_tokens=True, max_length=120)
+     with torch.no_grad():
+         out = enc(input_ids=inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device))
+         emb = out['last_hidden_state'].detach()  # [B, T, D]
+     return emb.to(dtype)
+
+ def build_pipeline():
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+     with open("configs/gpt_config.json", "r") as f:
+         gpt_cfg = json.load(f)
+     with open("configs/vq_config.json", "r") as f:
+         vq_cfg = json.load(f)
+     pipe = CondRefARPipeline.from_pretrained(".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype)
+     tok, enc = build_t5(device, dtype)
+     return pipe, tok, enc
+
+ pipe, tok, enc = build_pipeline()
+
+ def infer(prompt, control_image, cfg_scale, temperature, top_k, top_p):
+     emb = text_to_emb(prompt, tok, enc, pipe.device, pipe.dtype)
+     # the editor returns a dict; keep the RGB channels of the composited canvas
+     imgs = pipe(emb, control_image['composite'][:, :, :3], cfg_scale=cfg_scale, temperature=temperature, top_k=top_k, top_p=top_p)
+     return imgs[0]
+
+
+ EXAMPLES = [
+     [
+         "Aerial view of a large industrial area with multiple buildings and roads. There are several roads and highways visible in the image, and there are several parking lots scattered throughout the area.",
+         "assets/examples/example1.jpg",
+         4.0, 1.0, 2000, 1.0,
+     ],
+     [
+         "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building.",
+         "assets/examples/example2.jpg",
+         5.0, 0.95, 2500, 0.95,
+     ],
+ ]
+
+
+ with gr.Blocks(title="CondRef-AR", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("## CondRef-AR: Controllable Aerial Image Generation")
+
+     with gr.Row(equal_height=True):
+         # left column: inputs
+         with gr.Column(scale=3):
+             prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe the city...")
+             editor = gr.ImageEditor(
+                 type="numpy", crop_size="1:1", canvas_size=(512, 512),
+                 label="Image"
+             )
+             with gr.Row():
+                 btn_gen = gr.Button("Generate", variant="primary")
+                 btn_clear = gr.Button("Clear")
+
+         # right column: parameters + output + examples
+         with gr.Column(scale=2):
+             with gr.Accordion("Advanced settings", open=False):
+                 cfg_scale = gr.Slider(1, 8, value=4, step=0.5, label="CFG scale")
+                 temperature = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Temperature")
+                 top_k = gr.Slider(50, 4000, value=2000, step=50, label="top_k")
+                 top_p = gr.Slider(0.5, 1.0, value=1.0, step=0.01, label="top_p")
+
+             output = gr.Image(type="pil", label="Result", height=512)
+
+             # clickable examples: clicking auto-fills the inputs and runs inference
+             gr.Examples(
+                 examples=EXAMPLES,
+                 inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
+                 outputs=output,
+                 fn=infer,
+                 cache_examples=False,
+                 examples_per_page=2,
+                 label="Examples"
+             )
+
+     # button events
+     btn_gen.click(
+         infer,
+         inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
+         outputs=output
+     )
+     btn_clear.click(lambda: (None, None), outputs=[editor, output])
+
+ if __name__ == "__main__":
+     demo.launch()
assets/aerial_img.gif ADDED
Git LFS Details
  • SHA256: 05f2288e0bd745fb1d5fcc11150bd89008b4c00954c015db1c43e62c5c8f8723
  • Pointer size: 133 Bytes
  • Size of remote file: 13.7 MB
assets/control_img.gif ADDED
Git LFS Details
  • SHA256: 1dfe7fff7a07392dce819237c939d2d94035572207c5d8810cbca5039aa331a2
  • Pointer size: 132 Bytes
  • Size of remote file: 7.73 MB
assets/evolution.png ADDED
Git LFS Details
  • SHA256: 39e9ca004589b22970d99d692fff9b89c56f2b9bd3fe9d41bbc1bb1cbde77727
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
assets/examples/example1.jpg ADDED
assets/examples/example2.jpg ADDED
assets/method.jpg ADDED
Git LFS Details
  • SHA256: fa759edd96abdc2411ea9c63a4879aed6bfacbef814aea15c76ebe1c75f2455a
  • Pointer size: 131 Bytes
  • Size of remote file: 222 kB
assets/samples.png ADDED
Git LFS Details
  • SHA256: 5a9c94f23cc1346e33cb67e2ae46b8033d5c3c299d55488fe7c9d090295ee3a6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
configs/gpt_config.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "gpt_name": "GPT-XL",
+     "image_size": 512,
+     "downsample_size": 16,
+     "vocab_size": 16384,
+     "num_classes": 1000,
+     "cls_token_num": 120,
+     "model_type": "t2i",
+     "adapter_size": "small",
+     "condition_type": "sketch",
+     "dtype": "bfloat16"
+ }
configs/vq_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+     "model_name": "VQ-16",
+     "image_size": 512,
+     "downsample_size": 16,
+     "n_q": 8,
+     "codebook_size": 16384,
+     "codebook_embed_dim": 8,
+     "latent_channels": 8
+ }
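
The two configs are coupled to each other and to the Quickstart in the README: `cls_token_num` (120) matches the `max_length=120` used when tokenizing prompts, and the GPT and VQ stages agree on the 512/16 = 32x32 token grid. An illustrative consistency check:

```python
import json

with open("configs/gpt_config.json") as f:
    gpt_cfg = json.load(f)
with open("configs/vq_config.json") as f:
    vq_cfg = json.load(f)

# presumably why prompts are tokenized with max_length=120 in the README
assert gpt_cfg["cls_token_num"] == 120
# both stages must agree on the image-token grid (512 // 16 = 32 per side)
assert gpt_cfg["image_size"] == vq_cfg["image_size"] == 512
assert gpt_cfg["downsample_size"] == vq_cfg["downsample_size"] == 16
```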
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ transformers
+ safetensors
+ pillow
+ numpy
+ xformers
sample.py ADDED
@@ -0,0 +1,26 @@
+ import json, torch
+ from CondRefAR.pipeline import CondRefARPipeline
+ from transformers import AutoTokenizer, T5EncoderModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ gpt_cfg = json.load(open("configs/gpt_config.json"))
+ vq_cfg = json.load(open("configs/vq_config.json"))
+ pipe = CondRefARPipeline.from_pretrained(".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype)
+
+ tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
+ enc = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype).to(device).eval()
+
+ prompt = "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building."
+ control = "assets/examples/example2.jpg"
+
+ from PIL import Image
+ control_img = Image.open(control).convert("RGB")
+
+ inputs = tok([prompt], return_tensors="pt", padding="max_length", truncation=True, max_length=120)
+ with torch.no_grad():
+     emb = enc(input_ids=inputs["input_ids"].to(device), attention_mask=inputs["attention_mask"].to(device)).last_hidden_state
+
+ imgs = pipe(emb, control_img, cfg_scale=4, temperature=1.0, top_k=2000, top_p=1.0)
+ imgs[0].save("sample.png")
weights/sketch-gpt-xl.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:936bf74f9d71ae61ab183d0b6dc133362e2109f73105d3000d371fe6c2d52f3b
+ size 3350054432
weights/vq-16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c98b63a8f1da5c325443a7372fdfeb0ca59037d55d31cfccc6b157041fb924e
+ size 287832820
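
The checkpoints are plain safetensors files, so they can be inspected without building the models (illustrative; `safetensors` is already listed in `requirements.txt`):

```python
from safetensors.torch import load_file

# loads tensors on CPU by default
state = load_file("weights/vq-16.safetensors")
n_params = sum(t.numel() for t in state.values())
print(f"{len(state)} tensors, {n_params / 1e6:.1f}M parameters")
```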