import gradio as gr
import torch
import json

from CondRefAR.pipeline import CondRefARPipeline
from transformers import AutoTokenizer, T5EncoderModel


# Simplified: use transformers' flan-t5-xl directly to extract text embeddings
def build_t5(device, dtype):
    tok = AutoTokenizer.from_pretrained("google/flan-t5-xl")
    enc = T5EncoderModel.from_pretrained("google/flan-t5-xl", torch_dtype=dtype)
    enc = enc.to(device)
    enc.eval()
    return tok, enc


def text_to_emb(prompt, tok, enc, device, dtype):
    # Tokenize the prompt and run it through the frozen T5 encoder
    inputs = tok(
        [prompt],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        max_length=120,
    )
    with torch.no_grad():
        out = enc(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
        )
    emb = out["last_hidden_state"].detach()  # [B, T, D]
    return emb.to(dtype)


def build_pipeline():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    with open("configs/gpt_config.json", "r") as f:
        gpt_cfg = json.load(f)
    with open("configs/vq_config.json", "r") as f:
        vq_cfg = json.load(f)
    pipe = CondRefARPipeline.from_pretrained(".", gpt_cfg, vq_cfg, device=device, torch_dtype=dtype)
    tok, enc = build_t5(device, dtype)
    return pipe, tok, enc


pipe, tok, enc = build_pipeline()


def infer(prompt, control_image, cfg_scale, temperature, top_k, top_p):
    # Encode the prompt, then condition generation on the edited control image (RGB only)
    emb = text_to_emb(prompt, tok, enc, pipe.device, pipe.dtype)
    imgs = pipe(
        emb,
        control_image["composite"][:, :, :3],
        cfg_scale=cfg_scale,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    return imgs[0]


EXAMPLES = [
    [
        "Aerial view of a large industrial area with multiple buildings and roads. There are several roads and highways visible in the image, and there are several parking lots scattered throughout the area.",
        "assets/examples/example1.jpg",
        4.0,
        1.0,
        2000,
        1.0,
    ],
    [
        "Aerial view of a forested area with a river running through it. On the right side of the image, there is a small town or village with a red-roofed building.",
        "assets/examples/example2.jpg",
        5.0,
        0.95,
        2500,
        0.95,
    ],
]

with gr.Blocks(title="CondRef-AR", theme=gr.themes.Soft()) as demo:
    gr.Markdown("## CondRef-AR: Controllable Aerial Image Generation")

    with gr.Row(equal_height=True):
        # Left column: inputs
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=2, placeholder="Describe the city...")
            editor = gr.ImageEditor(
                type="numpy",
                crop_size="1:1",
                canvas_size=(512, 512),
                label="Image",
            )
            with gr.Row():
                btn_gen = gr.Button("Generate", variant="primary")
                btn_clear = gr.Button("Clear")

        # Right column: parameters + output + examples
        with gr.Column(scale=2):
            with gr.Accordion("Advanced settings", open=False):
                cfg_scale = gr.Slider(1, 8, value=4, step=0.5, label="CFG scale")
                temperature = gr.Slider(0.5, 1.5, value=1.0, step=0.05, label="Temperature")
                top_k = gr.Slider(50, 4000, value=2000, step=50, label="top_k")
                top_p = gr.Slider(0.5, 1.0, value=1.0, step=0.01, label="top_p")
            output = gr.Image(type="pil", label="Result", height=512)

            # Clickable examples: selecting one fills the inputs and runs inference
            gr.Examples(
                examples=EXAMPLES,
                inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
                outputs=output,
                fn=infer,
                cache_examples=False,
                run_on_click=True,
                examples_per_page=2,
                label="Examples",
            )

    # Button events
    btn_gen.click(
        infer,
        inputs=[prompt, editor, cfg_scale, temperature, top_k, top_p],
        outputs=output,
    )
    btn_clear.click(lambda: (None, None), outputs=[editor, output])


if __name__ == "__main__":
    demo.launch()