Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +115 -0
- .github/workflows/logo.gif +3 -0
- .github/workflows/publish.yaml +29 -0
- .gitignore +175 -0
- .home/.modelscope/credentials/session +1 -0
- .swp +0 -0
- LICENSE +201 -0
- README.md +783 -0
- README_zh.md +784 -0
- assets/egg.mp4 +3 -0
- comp_attn_bbox_layout.png +0 -0
- comp_attn_trajectory.png +0 -0
- diffsynth/__init__.py +1 -0
- diffsynth/configs/__init__.py +2 -0
- diffsynth/configs/model_configs.py +518 -0
- diffsynth/configs/vram_management_module_maps.py +197 -0
- diffsynth/core/__init__.py +5 -0
- diffsynth/core/attention/__init__.py +1 -0
- diffsynth/core/attention/attention.py +121 -0
- diffsynth/core/data/__init__.py +1 -0
- diffsynth/core/data/operators.py +218 -0
- diffsynth/core/data/unified_dataset.py +112 -0
- diffsynth/core/gradient/__init__.py +1 -0
- diffsynth/core/gradient/gradient_checkpoint.py +34 -0
- diffsynth/core/loader/__init__.py +3 -0
- diffsynth/core/loader/config.py +117 -0
- diffsynth/core/loader/file.py +121 -0
- diffsynth/core/loader/model.py +79 -0
- diffsynth/core/vram/__init__.py +2 -0
- diffsynth/core/vram/disk_map.py +93 -0
- diffsynth/core/vram/initialization.py +21 -0
- diffsynth/core/vram/layers.py +475 -0
- diffsynth/datasets/mvdataset.py +393 -0
- diffsynth/diffusion/__init__.py +6 -0
- diffsynth/diffusion/base_pipeline.py +439 -0
- diffsynth/diffusion/flow_match.py +179 -0
- diffsynth/diffusion/logger.py +43 -0
- diffsynth/diffusion/loss.py +119 -0
- diffsynth/diffusion/parsers.py +70 -0
- diffsynth/diffusion/runner.py +129 -0
- diffsynth/diffusion/training_module.py +212 -0
- diffsynth/models/comp_attn_model.py +592 -0
- diffsynth/models/dinov3_image_encoder.py +94 -0
- diffsynth/models/flux2_dit.py +1057 -0
- diffsynth/models/flux2_text_encoder.py +58 -0
- diffsynth/models/flux2_vae.py +0 -0
- diffsynth/models/flux_controlnet.py +384 -0
- diffsynth/models/flux_dit.py +395 -0
- diffsynth/models/flux_infiniteyou.py +129 -0
- diffsynth/models/flux_ipadapter.py +110 -0
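As context for how a commit like this is produced: the commit title refers to the `upload_folder` helper in `huggingface_hub`, which pushes a local directory to the Hub as a single commit. A minimal sketch follows; the repository ID and folder path are placeholders, not values taken from this diff.

```python
from huggingface_hub import upload_folder

# Push every file under the local folder to the Hub as one commit.
# repo_id and folder_path are placeholders, not values from this diff.
upload_folder(
    repo_id="your-username/your-repo",
    folder_path="./DiffSynth-Studio",
    commit_message="Upload folder using huggingface_hub",
)
```

Patterns registered in `.gitattributes` (the first change below) determine which of the uploaded files are stored via Git LFS.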
.gitattributes
CHANGED
@@ -33,3 +33,118 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+.github/workflows/logo.gif filter=lfs diff=lfs merge=lfs -text
+assets/egg.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/Comp-Attn.pdf filter=lfs diff=lfs merge=lfs -text
+examples/InstanceV.pdf filter=lfs diff=lfs merge=lfs -text
+examples/wanvideo/model_training/egg_statemachine_dataset/egg_4fps_640x320.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_1280x720.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_640x320.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_NonexNone.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/wanvideo/model_training/egg_statemachine_dataset/egg_8fps_448x256.mp4 filter=lfs diff=lfs merge=lfs -text
+output/_smoke_352x640/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/_smoke_352x640/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
+output/_smoke_352x640/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/_smoke_352x640/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/_smoke_352x640/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-inside.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-fron.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-next-to.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-inside.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-a-ca.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-sm.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_352x640_stable/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_batch10/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-mc-lora_epoch4_1/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-statemachine-egg/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-statemachine-egg_cooked2raw/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-statemachine-egg_moveup_long/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-statemachine-egg_moveup_long20/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
+output/wan2.1-1.3b-statemachine-egg_moveup_long20_promptclean/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/boat_seagull_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/deer_approach_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/dog_running.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/dog_running_baseline.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/four_people_talking.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/four_pigeons_orbit_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/multi_instances_animals.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/single_car_sweep.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/three_diagonal_motion.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/two_crossing_athletes.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/two_people_talking.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/two_scooters_crossing_20260105_112906.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/two_scooters_crossing_20260105_113836.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/two_scooters_crossing_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev/two_students_drone_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_00.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_01.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_02.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_03.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_04.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_05.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_06.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_07.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_08.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_09.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_10.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_11.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_12.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_13.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_14.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_15.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_16.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_17.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_18.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev-new/case_19.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev_iground_infer_20260107_024437.mp4 filter=lfs diff=lfs merge=lfs -text
+outputs/instancev_iground_test.mp4 filter=lfs diff=lfs merge=lfs -text
+video.mp4 filter=lfs diff=lfs merge=lfs -text
+video_1_Wan2.1-T2V-1.3B-LoRA2.mp4 filter=lfs diff=lfs merge=lfs -text
+video_1_Wan2.1-T2V-1.3B-LoRA_epoch1.mp4 filter=lfs diff=lfs merge=lfs -text
+video_1_Wan2.1-T2V-1.3B.mp4 filter=lfs diff=lfs merge=lfs -text
+video_1_Wan2.1-T2V-1.3B_LoRA.mp4 filter=lfs diff=lfs merge=lfs -text
+video_comp_attn_pipeline[[:space:]]copy.mp4 filter=lfs diff=lfs merge=lfs -text
+video_comp_attn_pipeline.mp4 filter=lfs diff=lfs merge=lfs -text
+wandb/run-20251211_101851-syoqkmhy/run-syoqkmhy.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20251211_172331-jxaicuod/run-jxaicuod.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20251225_172459-gjtz0um5/run-gjtz0um5.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20251225_214534-3dh8lbav/run-3dh8lbav.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20251229_100816-zirij84a/run-zirij84a.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260102_054910-38oaloji/run-38oaloji.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260102_104929-zd02vtce/run-zd02vtce.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260102_162705-mr7vgtqn/run-mr7vgtqn.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260103_090415-36yjbun5/run-36yjbun5.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260103_115016-kurow4tk/run-kurow4tk.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260106_030539-rupbhtts/run-rupbhtts.wandb filter=lfs diff=lfs merge=lfs -text
+wandb/run-20260110_110203-bl4gd6wi/run-bl4gd6wi.wandb filter=lfs diff=lfs merge=lfs -text
.github/workflows/logo.gif
ADDED
Binary file, stored with Git LFS.
.github/workflows/publish.yaml
ADDED
@@ -0,0 +1,29 @@
name: release

on:
  push:
    tags:
      - 'v**'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}-publish
  cancel-in-progress: true

jobs:
  build-n-publish:
    runs-on: ubuntu-20.04
    #if: startsWith(github.event.ref, 'refs/tags')
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.10
        uses: actions/setup-python@v2
        with:
          python-version: '3.10'
      - name: Install wheel
        run: pip install wheel==0.44.0 && pip install -r requirements.txt
      - name: Build DiffSynth
        run: python setup.py sdist bdist_wheel
      - name: Publish package to PyPI
        run: |
          pip install twine
          twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
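The build and publish steps in this workflow can be approximated locally for testing. A rough sketch, assuming `wheel` and `twine` are installed and that a `PYPI_API_TOKEN` environment variable stands in for the repository secret of the same name:

```python
import glob
import os
import subprocess

# Build source and wheel distributions, mirroring the "Build DiffSynth" step.
subprocess.run(["python", "setup.py", "sdist", "bdist_wheel"], check=True)

# Upload to PyPI, mirroring the "Publish package to PyPI" step.
# PYPI_API_TOKEN is a placeholder for the CI secret of the same name.
subprocess.run(
    [
        "twine", "upload", *glob.glob("dist/*"), "--skip-existing",
        "-u", "__token__", "-p", os.environ["PYPI_API_TOKEN"],
    ],
    check=True,
)
```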
.gitignore
ADDED
@@ -0,0 +1,175 @@
/data
/models
/scripts
/diffusers
*.pkl
*.safetensors
*.pth
*.ckpt
*.pt
*.bin
*.DS_Store
*.msc
*.mv
log*.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.home/.modelscope/credentials/session
ADDED
@@ -0,0 +1 @@
13921be3c1924b38a5a21db02dce6b94
.swp
ADDED
Binary file (12.3 kB).
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [2023] [Zhongjie Duan]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
ADDED
|
@@ -0,0 +1,783 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DiffSynth-Studio
|
| 2 |
+
|
| 3 |
+
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
| 4 |
+
|
| 5 |
+
[](https://pypi.org/project/DiffSynth/)
|
| 6 |
+
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
| 7 |
+
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
| 8 |
+
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
| 9 |
+
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
| 10 |
+
|
| 11 |
+
[切换到中文版](./README_zh.md)
|
| 12 |
+
|
| 13 |
+
## Introduction
|
| 14 |
+
|
| 15 |
+
Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework construction, aggregate the power of the open-source community, and explore the boundaries of generative model technology!
|
| 16 |
+
|
| 17 |
+
DiffSynth currently includes two open-source projects:
|
| 18 |
+
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, targeting academia, and providing cutting-edge model capability support.
|
| 19 |
+
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, targeting industry, and providing higher computational performance and more stable features.
|
| 20 |
+
|
| 21 |
+
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. Welcome to experience our carefully crafted productized features:
|
| 22 |
+
|
| 23 |
+
* ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
|
| 24 |
+
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
|
| 25 |
+
|
| 26 |
+
> DiffSynth-Studio Documentation: [中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
|
| 27 |
+
|
| 28 |
+
We believe that a well-developed open-source code framework can lower the threshold for technical exploration. We have achieved many [interesting technologies](#innovative-achievements) based on this codebase. Perhaps you also have many wild ideas, and with DiffSynth-Studio, you can quickly realize these ideas. For this reason, we have prepared detailed documentation for developers. We hope that through these documents, developers can understand the principles of Diffusion models, and we look forward to expanding the boundaries of technology together with you.
|
| 29 |
+
|
| 30 |
+
## Update History
|
| 31 |
+
|
| 32 |
+
> DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
|
| 33 |
+
|
| 34 |
+
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
| 35 |
+
|
| 36 |
+
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research.
|
| 37 |
+
|
| 38 |
+
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
|
| 39 |
+
- [Documentation](/docs/en/README.md) online: Our documentation is still continuously being optimized and updated
|
| 40 |
+
- [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded, supporting layer-level disk offload, releasing both memory and VRAM simultaneously
|
| 41 |
+
- New model support
|
| 42 |
+
- Z-Image Turbo: [Model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [Documentation](/docs/en/Model_Details/Z-Image.md), [Code](/examples/z_image/)
|
| 43 |
+
- FLUX.2-dev: [Model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [Documentation](/docs/en/Model_Details/FLUX2.md), [Code](/examples/flux2/)
|
| 44 |
+
- Training framework upgrade
|
| 45 |
+
- [Split Training](/docs/zh/Training/Split_Training.md): Supports automatically splitting the training process into two stages: data processing and training (even for training ControlNet or any other model). Computations that do not require gradient backpropagation, such as text encoding and VAE encoding, are performed during the data processing stage, while other computations are handled during the training stage. Faster speed, less VRAM requirement.
|
| 46 |
+
- [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
|
| 47 |
+
- [FP8 Training](/docs/zh/Training/FP8_Precision.md): FP8 can be applied to any non-training model during training, i.e., models with gradients turned off or gradients that only affect LoRA weights.
|
| 48 |
+
|
| 49 |
+
<details>
|
| 50 |
+
<summary>More</summary>
|
| 51 |
+
|
| 52 |
+
- **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.
|
| 53 |
+
|
| 54 |
+
- **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.
|
| 55 |
+
|
| 56 |
+
- **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, adding another member to the Wan model ecosystem.
|
| 57 |
+
|
| 58 |
+
- **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! This model was jointly developed and open-sourced by us and Taobao Experience Design Team. Built upon Qwen-Image, the model is specifically designed for e-commerce poster scenarios, supporting precise partition layout control. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
|
| 59 |
+
|
| 60 |
+
- **September 9, 2025** Our training framework supports various training modes. Currently adapted for Qwen-Image, in addition to the standard SFT training mode, Direct Distill is now supported. Please refer to [our sample code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training functions.
|
| 61 |
+
|
| 62 |
+
- **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).
|
| 63 |
+
|
| 64 |
+
- **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared to the V1 version, the training dataset has been changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better conform to Qwen-Image's own image distribution and style. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
|
| 65 |
+
|
| 66 |
+
- **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural control LoRA model, adopting the In Context technical route, supporting multiple categories of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
|
| 67 |
+
|
| 68 |
+
- **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
|
| 69 |
+
|
| 70 |
+
- **August 19, 2025** 🔥 Qwen-Image-Edit open-sourced, welcome a new member to the image editing model family!
|
| 71 |
+
|
| 72 |
+
- **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
|
| 73 |
+
|
| 74 |
+
- **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset. This is an image dataset generated using the Qwen-Image model, containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide annotations for image descriptions, entities, and structural control images for each image. Developers can use this dataset to train Qwen-Image models' ControlNet and EliGen models. We aim to promote technological development through open-sourcing!
|
| 75 |
+
|
| 76 |
+
- **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
|
| 77 |
+
|
| 78 |
+
- **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
|
| 79 |
+
|
| 80 |
+
- **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image, following the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure has been modified to LoRA, thus being better compatible with other open-source ecosystem models.
|
| 81 |
+
|
| 82 |
+
- **August 7, 2025** We open-sourced the entity control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen can achieve entity-level controlled text-to-image generation. Technical details can be found in [the paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
|
| 83 |
+
|
| 84 |
+
- **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
|
| 85 |
+
|
| 86 |
+
- **August 4, 2025** 🔥 Qwen-Image open-sourced, welcome a new member to the image generation model family!
|
| 87 |
+
|
| 88 |
+
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
- **July 28, 2025** Wan 2.2 open-sourced. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. For more details, please refer to [./examples/wanvideo/](./examples/wanvideo/).
- **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of Large Language Models (LLMs) with the image generation capabilities of diffusion models. This framework supports seamless image understanding, generation, and editing tasks.
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- GitHub Repository: https://github.com/modelscope/Nexus-Gen
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
- **June 15, 2025** ModelScope's official evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Please refer to the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out.
- **March 25, 2025** Our new project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now open-sourced! It focuses on stable model deployment for industrial use, providing better engineering support, higher computational performance, and more stable features.
- **March 31, 2025** We support InfiniteYou, a face feature preservation method for FLUX. More details can be found in [./examples/InfiniteYou/](./examples/InfiniteYou/).
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
- **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
- **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! Advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
- **December 31, 2024** We propose EliGen, a new framework for entity-level controlled text-to-image generation, supplemented with an inpainting fusion pipeline, extending its capabilities to image inpainting tasks. EliGen can seamlessly integrate existing community models such as IP-Adapter and In-Context LoRA, enhancing their versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
- **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling video generation with resolutions of 129x720x1280 on 24GB VRAM or 129x512x384 on just 6GB VRAM. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
- **December 18, 2024** We propose ArtAug, a method to improve text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. This model incorporates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, thereby improving the quality of generated images.
- Paper: https://arxiv.org/abs/2412.12888
- Example: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)
- **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models and can be freely combined, even if their structures are different. Additionally, ControlNet models are compatible with high-resolution optimization and partition control technologies, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
- **October 8, 2024** We released extended LoRAs based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
- **August 22, 2024** This project now supports CogVideoX-5B. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:
- Text-to-video
- Video editing
- Self super-resolution
- Video interpolation
- **August 22, 2024** We implemented an interesting brush feature that supports all text-to-image models. Now you can create stunning images with the assistance of AI using the brush!
- Use it in our [WebUI](#usage-in-webui).
- **August 21, 2024** DiffSynth-Studio now supports FLUX.
- Enable CFG and high-resolution inpainting to improve visual quality. See [here](/examples/image_synthesis/README.md)
- LoRA, ControlNet, and other addon models will be released soon.
- **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to achieve long video generation of up to 128 frames.
- [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Source code has been released in this repository. See [`examples/ExVideo`](./examples/ExVideo/).
- Model has been released at [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
- Technical report has been released at [arXiv](https://arxiv.org/abs/2406.14130).
- You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
- **June 13, 2024** DiffSynth Studio has migrated to ModelScope. The development team has also transitioned from "me" to "us". Of course, I will still participate in subsequent development and maintenance work.
- **January 29, 2024** We propose Diffutoon, an excellent cartoon coloring solution.
- [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Source code has been released in this project.
- Technical report (IJCAI 2024) has been released at [arXiv](https://arxiv.org/abs/2401.16224).
- **December 8, 2023** We decided to initiate a new project aimed at unleashing the potential of diffusion models, especially in video synthesis. The development work of this project officially began.
- **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.
- sd-webui extension has been released at [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
- Demonstration videos have been showcased on Bilibili, including three tasks:
- [Video Deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
- [Video Interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
- [Image-Driven Video Rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
- Technical report has been released at [arXiv](https://arxiv.org/abs/2311.09265).
- Unofficial ComfyUI extensions developed by other users have been released at [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
- **October 1, 2023** We released an early version of the project named FastSDXL. This was an initial attempt to build a diffusion engine.
- Source code has been released at [GitHub](https://github.com/Artiprocher/FastSDXL).
- FastSDXL includes a trainable OLSS scheduler to improve efficiency.
- The original repository of OLSS is located [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
- Technical report (CIKM 2023) has been released at [arXiv](https://arxiv.org/abs/2305.14677).
- Demonstration video has been released at [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
- Since OLSS requires additional training, we did not implement it in this project.
- **August 29, 2023** We propose DiffSynth, a video synthesis framework.
- [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
- Source code has been released at [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
- Technical report (ECML PKDD 2024) has been released at [arXiv](https://arxiv.org/abs/2308.03463).
</details>
## Installation

Install from source (recommended):

```
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
```

<details>
<summary>Other installation methods</summary>

Install from PyPI (version updates may be delayed; for the latest features, install from source):

```
pip install diffsynth
```

If you encounter problems during installation, they might be caused by upstream dependencies. Please check the docs of these packages:

* [torch](https://pytorch.org/get-started/locally/)
* [sentencepiece](https://github.com/google/sentencepiece)
* [cmake](https://cmake.org)
* [cupy](https://docs.cupy.dev/en/stable/install.html)

</details>

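After installation, a quick import check confirms that the package resolves correctly. This is a minimal, hypothetical smoke test rather than an official script; the version attribute is read defensively in case it is not defined:

```python
# Minimal smoke test: verify that diffsynth can be imported.
import diffsynth

print(getattr(diffsynth, "__version__", "diffsynth imported successfully"))
```
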
## Basic Framework

DiffSynth-Studio redesigns the inference and training pipelines for mainstream Diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.

<details>
<summary>Environment Variable Configuration</summary>

> Before running model inference or training, you can configure settings such as the model download source via [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md).
>
> By default, this project downloads models from ModelScope. For users outside China, you can configure the system to download models from the ModelScope international site as follows:
>
> ```python
> import os
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
> ```
>
> To download models from other sources, please modify the environment variable [DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).

</details>

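All of the Quick Start examples below share the same VRAM-management pattern: a `vram_config` dictionary describes where parameters live when idle (`offload_*`), where they are staged between uses (`onload_*`), how they are prepared for the GPU (`preparing_*`), and which dtype and device are used for the actual forward pass (`computation_*`), while `vram_limit` caps how much GPU memory the framework may fill. The sketch below only restates the configuration already used in those examples; the comments are our interpretation of the keys, not an exhaustive API reference.

```python
import torch

# A minimal sketch of the recurring VRAM-management configuration.
# Key names are taken from the Quick Start examples below; the comments are interpretive.
vram_config = {
    "offload_dtype": torch.bfloat16,      # dtype of parameters while offloaded ("disk" is also used below)
    "offload_device": "cpu",              # where idle parameters rest ("cpu" or "disk")
    "onload_dtype": torch.bfloat16,       # dtype of the staged copy kept ready for use
    "onload_device": "cpu",               # staging device for parameters about to be used
    "preparing_dtype": torch.bfloat16,    # dtype while parameters are moved toward the GPU
    "preparing_device": "cuda",           # device where that preparation happens
    "computation_dtype": torch.bfloat16,  # dtype used in the actual forward pass
    "computation_device": "cuda",         # device used in the actual forward pass
}

# vram_limit is the total GPU memory in GiB minus some headroom;
# torch.cuda.mem_get_info returns (free_bytes, total_bytes).
vram_limit = torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5
```
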
### Image Synthesis


#### Z-Image: [/docs/en/Model_Details/Z-Image.md](/docs/en/Model_Details/Z-Image.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model for inference. FP8 quantization significantly degrades image quality, so we do not recommend enabling any quantization for the Z-Image Turbo model. CPU offloading is recommended, and the model can run with as little as 8 GB of GPU memory.

```python
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": torch.bfloat16,
    "offload_device": "cpu",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = ZImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
image.save("image.jpg")
```
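
Reusing the same `pipe` object avoids reloading the weights, so sweeping a few seeds is cheap after the first call. A short follow-up sketch that only reuses `pipe` and `prompt` from the code above:

```python
# Generate a few variations by reusing the pipeline and prompt defined above.
for seed in (0, 7, 123):
    image = pipe(prompt=prompt, seed=seed, rand_device="cuda")
    image.save(f"image_seed{seed}.jpg")
```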

</details>

<details>

<summary>Examples</summary>

Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)

| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|

</details>

#### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model for inference. VRAM management is enabled, and the framework automatically loads model parameters based on available GPU memory. The model can run with as little as 10 GB of VRAM.

```python
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = Flux2ImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
image.save("image.jpg")
```
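
The configuration above keeps idle weights on disk and the staged copies in FP8, which is what makes the ~10 GB figure possible at the cost of extra loading time. If you have enough system RAM, offloading to CPU memory instead should reduce disk traffic; the variant below simply rearranges the same keys, mirroring the FP8 CPU-offload setup used in the FLUX.1 Quick Start further down, and is a sketch rather than a benchmarked recommendation:

```python
import torch

# Variant of the vram_config above: keep offloaded weights in CPU RAM (FP8) instead of on disk.
# Requires enough free system memory to hold the offloaded parameters.
vram_config = {
    "offload_dtype": torch.float8_e4m3fn,
    "offload_device": "cpu",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
```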

</details>

<details>

<summary>Examples</summary>

Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)

| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|

</details>

#### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.

```python
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
)
# The prompt below is Chinese: "Exquisite portrait, underwater girl, flowing blue dress, drifting hair,
# translucent light and shadow, surrounding bubbles, serene face, fine details, dreamy and beautiful."
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=40)
image.save("image.jpg")
```
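
The prompt in the example is Chinese (Qwen-Image also handles Chinese text rendering well), but the pipeline accepts English prompts in exactly the same way. A short follow-up that reuses `pipe` from above with only the arguments already shown:

```python
# Reuse the pipeline above with an English prompt.
prompt_en = ("Exquisite portrait of an underwater girl in a flowing blue dress, drifting hair, "
             "translucent light, surrounding bubbles, serene face, fine details, dreamy and beautiful.")
image = pipe(prompt_en, seed=1, num_inference_steps=40)
image.save("image_en.jpg")
```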

</details>

<details>

<summary>Model Lineage</summary>

```mermaid
graph LR;
    Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
    Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
    Qwen/Qwen-Image-->EliGen-Series;
    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
    DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
    EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
    Qwen/Qwen-Image-->Distill-Series;
    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
    Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
    Qwen/Qwen-Image-->ControlNet-Series;
    ControlNet-Series-->Blockwise-ControlNet-Series;
    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
    Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
    ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
    Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
```

</details>

<details>

<summary>Examples</summary>

Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/qwen_image/)

| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|

</details>

#### FLUX.1: [/docs/en/Model_Details/FLUX.md](/docs/en/Model_Details/FLUX.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.

```python
import torch
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig

vram_config = {
    "offload_dtype": torch.float8_e4m3fn,
    "offload_device": "cpu",
    "onload_dtype": torch.float8_e4m3fn,
    "onload_device": "cpu",
    "preparing_dtype": torch.float8_e4m3fn,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
    ],
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
)
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
image = pipe(prompt=prompt, seed=0)
image.save("image.jpg")
```

</details>

<details>

<summary>Model Lineage</summary>

```mermaid
graph LR;
    FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
    FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
    FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
    black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
    FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
    FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
    FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
    black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
    black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
    black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
    black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
    Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
    black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
    Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
```

</details>

<details>

<summary>Examples</summary>

Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)

| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|-|
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|

</details>

### Video Synthesis

https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314

#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)

<details>

<summary>Quick Start</summary>

Running the following code will quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.

```python
import torch
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig

vram_config = {
    "offload_dtype": "disk",
    "offload_device": "disk",
    "onload_dtype": torch.bfloat16,
    "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16,
    "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16,
    "computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)

# The prompt below is Chinese: a documentary-style shot of a lively puppy sprinting across a green
# meadow (brown-yellow fur, ears up, sunlight on soft shiny fur, scattered wildflowers, blue sky,
# side-tracking medium shot capturing the motion). The negative prompt lists common artifacts to
# avoid (oversaturation, overexposure, blurry details, subtitles, low quality, JPEG artifacts,
# deformed limbs, extra fingers, static frames, cluttered backgrounds, and so on).
video = pipe(
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
seed=0, tiled=True,
)
save_video(video, "video.mp4", fps=15, quality=5)
```
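
In our reading of the example, `tiled=True` asks the pipeline to process the video in tiles (mainly in the VAE stage), trading some speed for a lower peak VRAM footprint; see the Wan docs linked above for details. A short follow-up sketch that reuses `pipe` and `save_video` from above with only the arguments already shown:

```python
# Reuse the pipeline above for a second clip with a different prompt.
video = pipe(
    prompt="A documentary-style shot of a kitten chasing a butterfly across a sunlit meadow, "
           "shallow depth of field, smooth side-tracking camera.",
    negative_prompt="blurry, overexposed, static frame, deformed limbs, extra fingers, subtitles, cluttered background",
    seed=1, tiled=True,
)
save_video(video, "video_2.mp4", fps=15, quality=5)
```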

</details>

<details>

<summary>Model Lineage</summary>

```mermaid
graph LR;
    Wan-Series-->Wan2.1-Series;
    Wan-Series-->Wan2.2-Series;
    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
    Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
    Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
    Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
    iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
    Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
    Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
    Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
    Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
    Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
    Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
    Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
    Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
    Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
    Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
    Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
    Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
    Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
    Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
    Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
    Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
    Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
```

</details>

<details>

<summary>Examples</summary>

Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)

| Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
| 645 |
+
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
| 646 |
+
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
| 647 |
+
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
| 648 |
+
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
| 649 |
+
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
| 650 |
+
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
| 651 |
+
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
| 652 |
+
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
| 653 |
+
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|
| 654 |
+
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
| 655 |
+
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
| 656 |
+
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
| 657 |
+
|
| 658 |
+
</details>
|
| 659 |
+
|
| 660 |
+
## Innovative Achievements
|
| 661 |
+
|
| 662 |
+
DiffSynth-Studio is not just an engineered model framework, but also an incubator for innovative achievements.
|
| 663 |
+
|
| 664 |
+
<details>
|
| 665 |
+
|
| 666 |
+
<summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
|
| 667 |
+
|
| 668 |
+
- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
|
| 669 |
+
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
|
| 670 |
+
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|
| 671 |
+
|
| 672 |
+
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|
| 673 |
+
|-|-|-|-|-|
|
| 674 |
+
||||||
|
| 675 |
+
|
| 676 |
+
</details>
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
<details>
|
| 680 |
+
|
| 681 |
+
<summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>
|
| 682 |
+
|
| 683 |
+
- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
|
| 684 |
+
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
|
| 685 |
+
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
|
| 686 |
+
|
| 687 |
+
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|
| 688 |
+
|-|-|-|-|-|
|
| 689 |
+
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |||||
|
| 690 |
+
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |||||
|
| 691 |
+
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |||||
|
| 692 |
+
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |||||
|
| 693 |
+
|
| 694 |
+
</details>
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
<details>
|
| 698 |
+
|
| 699 |
+
<summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
|
| 700 |
+
|
| 701 |
+
- Detailed Page: https://github.com/modelscope/Nexus-Gen
|
| 702 |
+
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
| 703 |
+
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
| 704 |
+
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
| 705 |
+
- Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
| 706 |
+
|
| 707 |
+

|
| 708 |
+
|
| 709 |
+
</details>
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
<details>
|
| 713 |
+
|
| 714 |
+
<summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
|
| 715 |
+
|
| 716 |
+
- Detailed Page: [./examples/ArtAug/](./examples/ArtAug/)
|
| 717 |
+
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
|
| 718 |
+
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
| 719 |
+
- Online Experience: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
| 720 |
+
|
| 721 |
+
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
| 722 |
+
|-|-|
|
| 723 |
+
|||
|
| 724 |
+
|
| 725 |
+
</details>
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
<details>
|
| 729 |
+
|
| 730 |
+
<summary>EliGen: Precise Image Partition Control</summary>
|
| 731 |
+
|
| 732 |
+
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
| 733 |
+
- Sample Code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
|
| 734 |
+
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
| 735 |
+
- Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
| 736 |
+
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
| 737 |
+
|
| 738 |
+
|Entity Control Region|Generated Image|
|
| 739 |
+
|-|-|
|
| 740 |
+
|||
|
| 741 |
+
|
| 742 |
+
</details>
|
| 743 |
+
|
| 744 |
+
|
| 745 |
+
<details>
|
| 746 |
+
|
| 747 |
+
<summary>ExVideo: Extended Training for Video Generation Models</summary>
|
| 748 |
+
|
| 749 |
+
- Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
| 750 |
+
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
|
| 751 |
+
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
|
| 752 |
+
- Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
| 753 |
+
|
| 754 |
+
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
| 755 |
+
|
| 756 |
+
</details>
|
| 757 |
+
|
| 758 |
+
|
| 759 |
+
<details>
|
| 760 |
+
|
| 761 |
+
<summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
|
| 762 |
+
|
| 763 |
+
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
| 764 |
+
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
|
| 765 |
+
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)
|
| 766 |
+
|
| 767 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
| 768 |
+
|
| 769 |
+
</details>
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
<details>
|
| 773 |
+
|
| 774 |
+
<summary>DiffSynth: The Original Version of This Project</summary>
|
| 775 |
+
|
| 776 |
+
- Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
|
| 777 |
+
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
|
| 778 |
+
- Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
|
| 779 |
+
|
| 780 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
| 781 |
+
|
| 782 |
+
</details>
|
| 783 |
+
|
README_zh.md
ADDED
|
@@ -0,0 +1,784 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DiffSynth-Studio
|
| 2 |
+
|
| 3 |
+
<a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
|
| 4 |
+
|
| 5 |
+
[](https://pypi.org/project/DiffSynth/)
|
| 6 |
+
[](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
|
| 7 |
+
[](https://github.com/modelscope/DiffSynth-Studio/issues)
|
| 8 |
+
[](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
|
| 9 |
+
[](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
|
| 10 |
+
|
| 11 |
+
[Switch to English](./README.md)
|
| 12 |
+
|
| 13 |
+
## 简介
|
| 14 |
+
|
| 15 |
+
欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
|
| 16 |
+
|
| 17 |
+
DiffSynth 目前包括两个开源项目:
|
| 18 |
+
* [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
|
| 19 |
+
* [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
|
| 20 |
+
|
| 21 |
+
[DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎,欢迎体验我们精心打造的产品化功能:
|
| 22 |
+
|
| 23 |
+
* 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
|
| 24 |
+
* ModelScope Civision (for global users): https://modelscope.ai/civision/home
|
| 25 |
+
|
| 26 |
+
> DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
|
| 27 |
+
|
| 28 |
+
我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio,你可以快速实现这些想法。为此,我们为开发者准备了详细的文档,我们希望通过这些文档,帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
|
| 29 |
+
|
| 30 |
+
## 更新历史
|
| 31 |
+
|
| 32 |
+
> DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
|
| 33 |
+
|
| 34 |
+
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
| 35 |
+
|
| 36 |
+
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。
|
| 37 |
+
|
| 38 |
+
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
|
| 39 |
+
- [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
|
| 40 |
+
- [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存
|
| 41 |
+
- 新模型支持
|
| 42 |
+
- Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/)
|
| 43 |
+
- FLUX.2-dev: [模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/)
|
| 44 |
+
- 训练框架升级
|
| 45 |
+
- [拆分训练](/docs/zh/Training/Split_Training.md):支持自动化地将训练过程拆分为数据处理和训练两阶段(即使训练的是 ControlNet 或其他任意模型),在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算,在训练阶段处理其他计算。速度更快,显存需求更少。
|
| 46 |
+
- [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。
|
| 47 |
+
- [FP8 训练](/docs/zh/Training/FP8_Precision.md):在训练中,FP8 精度可应用于任意不参与训练的模型,即梯度关闭或梯度仅影响 LoRA 权重的模型。
|
| 48 |
+
|
| 49 |
+
<details>
|
| 50 |
+
<summary>更多</summary>
|
| 51 |
+
|
| 52 |
+
- **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
|
| 53 |
+
|
| 54 |
+
- **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
|
| 55 |
+
|
| 56 |
+
- **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。
|
| 57 |
+
|
| 58 |
+
- **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
|
| 59 |
+
|
| 60 |
+
- **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善,以支持更全面的模型训练功能。
|
| 61 |
+
|
| 62 |
+
- **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
|
| 63 |
+
|
| 64 |
+
- **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
|
| 65 |
+
|
| 66 |
+
- **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
|
| 67 |
+
|
| 68 |
+
- **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
|
| 69 |
+
|
| 70 |
+
- **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
|
| 71 |
+
|
| 72 |
+
- **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
|
| 73 |
+
|
| 74 |
+
- **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
|
| 75 |
+
|
| 76 |
+
- **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
|
| 77 |
+
|
| 78 |
+
- **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
|
| 79 |
+
|
| 80 |
+
- **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。
|
| 81 |
+
|
| 82 |
+
- **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
|
| 83 |
+
|
| 84 |
+
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
|
| 85 |
+
|
| 86 |
+
- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
|
| 87 |
+
|
| 88 |
+
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
|
| 89 |
+
|
| 90 |
+
- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
|
| 91 |
+
|
| 92 |
+
- **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
|
| 93 |
+
- 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
|
| 94 |
+
- Github 仓库: https://github.com/modelscope/Nexus-Gen
|
| 95 |
+
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
|
| 96 |
+
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
|
| 97 |
+
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
| 98 |
+
|
| 99 |
+
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
|
| 100 |
+
|
| 101 |
+
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
|
| 102 |
+
|
| 103 |
+
- **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
|
| 104 |
+
|
| 105 |
+
- **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
| 106 |
+
|
| 107 |
+
- **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
|
| 108 |
+
|
| 109 |
+
- **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
|
| 110 |
+
|
| 111 |
+
- **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
|
| 112 |
+
- 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
|
| 113 |
+
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
|
| 114 |
+
- 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
|
| 115 |
+
- 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
| 116 |
+
|
| 117 |
+
- **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
|
| 118 |
+
|
| 119 |
+
- **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
|
| 120 |
+
- 论文: https://arxiv.org/abs/2412.12888
|
| 121 |
+
- 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
|
| 122 |
+
- 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
|
| 123 |
+
- 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
|
| 124 |
+
|
| 125 |
+
- **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
|
| 126 |
+
|
| 127 |
+
- **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
|
| 128 |
+
|
| 129 |
+
- **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
|
| 130 |
+
- 文本到视频
|
| 131 |
+
- 视频编辑
|
| 132 |
+
- 自我超分
|
| 133 |
+
- 视频插帧
|
| 134 |
+
|
| 135 |
+
- **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
|
| 136 |
+
- 在我们的 [WebUI](#usage-in-webui) 中使用它。
|
| 137 |
+
|
| 138 |
+
- **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
|
| 139 |
+
- 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
|
| 140 |
+
- LoRA、ControlNet 和其他附加模型将很快推出。
|
| 141 |
+
|
| 142 |
+
- **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
|
| 143 |
+
- [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
|
| 144 |
+
- 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
|
| 145 |
+
- 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
|
| 146 |
+
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
|
| 147 |
+
- 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
|
| 148 |
+
|
| 149 |
+
- **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
|
| 150 |
+
|
| 151 |
+
- **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
|
| 152 |
+
- [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
|
| 153 |
+
- 源代码已在此项目中发布。
|
| 154 |
+
- 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
|
| 155 |
+
|
| 156 |
+
- **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
|
| 157 |
+
|
| 158 |
+
- **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
|
| 159 |
+
- sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
|
| 160 |
+
- 演示视频已在 Bilibili 上展示,包含三个任务:
|
| 161 |
+
- [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
|
| 162 |
+
- [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
|
| 163 |
+
- [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
|
| 164 |
+
- 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
|
| 165 |
+
- 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
|
| 166 |
+
|
| 167 |
+
- **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
|
| 168 |
+
- 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
|
| 169 |
+
- FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
|
| 170 |
+
- OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
|
| 171 |
+
- 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
|
| 172 |
+
- 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
|
| 173 |
+
- 由于 OLSS 需要额外训练,我们未在本项目中实现它。
|
| 174 |
+
|
| 175 |
+
- **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
|
| 176 |
+
- [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
|
| 177 |
+
- 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
|
| 178 |
+
- 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
|
| 179 |
+
|
| 180 |
+
</details>
|
| 181 |
+
|
| 182 |
+
## 安装
|
| 183 |
+
|
| 184 |
+
从源码安装(推荐):
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
| 188 |
+
cd DiffSynth-Studio
|
| 189 |
+
pip install -e .
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
<details>
|
| 193 |
+
<summary>其他安装方式</summary>
|
| 194 |
+
|
| 195 |
+
从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
|
| 196 |
+
|
| 197 |
+
```
|
| 198 |
+
pip install diffsynth
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
| 202 |
+
|
| 203 |
+
* [torch](https://pytorch.org/get-started/locally/)
|
| 204 |
+
* [sentencepiece](https://github.com/google/sentencepiece)
|
| 205 |
+
* [cmake](https://cmake.org)
|
| 206 |
+
* [cupy](https://docs.cupy.dev/en/stable/install.html)
|
| 207 |
+
|
| 208 |
+
</details>
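安装完成后,可以用下面这段最小的 Python 脚本做一次简单自检,确认 `diffsynth` 及其 Pipeline 类能够正常导入(仅为示意性的自检草例,不会下载任何模型):

```python
# 最小自检脚本:确认 diffsynth 可以正常导入,并检查 CUDA 是否可用
import torch
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

print("diffsynth 导入成功:", QwenImagePipeline.__name__)
print("CUDA 可用:", torch.cuda.is_available())
```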
|
| 209 |
+
|
| 210 |
+
## 基础框架
|
| 211 |
+
|
| 212 |
+
DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
|
| 213 |
+
|
| 214 |
+
<details>
|
| 215 |
+
<summary>环境变量配置</summary>
|
| 216 |
+
|
| 217 |
+
> 在进行模型推理和训练前,可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。
|
| 218 |
+
>
|
| 219 |
+
> 本项目默认从魔搭社区下载模型。对于非中国区域的用户,可以通过以下配置从魔搭社区的国际站下载模型:
|
| 220 |
+
>
|
| 221 |
+
> ```python
|
| 222 |
+
> import os
|
| 223 |
+
> os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
|
| 224 |
+
> ```
|
| 225 |
+
>
|
| 226 |
+
> 如需从其他站点下载,请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。
|
| 227 |
+
|
| 228 |
+
</details>
|
| 229 |
+
|
| 230 |
+
### 图像生成模型
|
| 231 |
+
|
| 232 |
+

|
| 233 |
+
|
| 234 |
+
#### Z-Image:[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)
|
| 235 |
+
|
| 236 |
+
<details>
|
| 237 |
+
|
| 238 |
+
<summary>快速开始</summary>
|
| 239 |
+
|
| 240 |
+
运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload,最低 8G 显存即可运行。
|
| 241 |
+
|
| 242 |
+
```python
|
| 243 |
+
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
| 244 |
+
import torch
|
| 245 |
+
|
| 246 |
+
vram_config = {
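# 显存管理配置:为模型参数在 offload / onload / preparing / computation 各阶段
# 分别指定数据类型与存放设备;Z-Image Turbo 不建议量化,因此全程保持 bfloat16,
# 仅通过 CPU Offload 在闲置时把参数放到 CPU 上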
|
| 247 |
+
"offload_dtype": torch.bfloat16,
|
| 248 |
+
"offload_device": "cpu",
|
| 249 |
+
"onload_dtype": torch.bfloat16,
|
| 250 |
+
"onload_device": "cpu",
|
| 251 |
+
"preparing_dtype": torch.bfloat16,
|
| 252 |
+
"preparing_device": "cuda",
|
| 253 |
+
"computation_dtype": torch.bfloat16,
|
| 254 |
+
"computation_device": "cuda",
|
| 255 |
+
}
|
| 256 |
+
pipe = ZImagePipeline.from_pretrained(
|
| 257 |
+
torch_dtype=torch.bfloat16,
|
| 258 |
+
device="cuda",
|
| 259 |
+
model_configs=[
|
| 260 |
+
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
| 261 |
+
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
| 262 |
+
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
| 263 |
+
],
|
| 264 |
+
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
| 265 |
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
| 266 |
+
)
|
| 267 |
+
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
| 268 |
+
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
| 269 |
+
image.save("image.jpg")
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
</details>
|
| 273 |
+
|
| 274 |
+
<details>
|
| 275 |
+
|
| 276 |
+
<summary>示例代码</summary>
|
| 277 |
+
|
| 278 |
+
Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|
| 279 |
+
|
| 280 |
+
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
| 281 |
+
|-|-|-|-|-|-|-|
|
| 282 |
+
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
| 283 |
+
|
| 284 |
+
</details>
|
| 285 |
+
|
| 286 |
+
#### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
|
| 287 |
+
|
| 288 |
+
<details>
|
| 289 |
+
|
| 290 |
+
<summary>快速开始</summary>
|
| 291 |
+
|
| 292 |
+
运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。
|
| 293 |
+
|
| 294 |
+
```python
|
| 295 |
+
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
| 296 |
+
import torch
|
| 297 |
+
|
| 298 |
+
vram_config = {
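# 显存管理配置:offload_* 设为 "disk" 表示闲置参数卸载到磁盘,同时释放内存与显存;
# onload / preparing 阶段以 FP8(float8_e4m3fn)存放,computation 阶段再以 bfloat16 参与计算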
|
| 299 |
+
"offload_dtype": "disk",
|
| 300 |
+
"offload_device": "disk",
|
| 301 |
+
"onload_dtype": torch.float8_e4m3fn,
|
| 302 |
+
"onload_device": "cpu",
|
| 303 |
+
"preparing_dtype": torch.float8_e4m3fn,
|
| 304 |
+
"preparing_device": "cuda",
|
| 305 |
+
"computation_dtype": torch.bfloat16,
|
| 306 |
+
"computation_device": "cuda",
|
| 307 |
+
}
|
| 308 |
+
pipe = Flux2ImagePipeline.from_pretrained(
|
| 309 |
+
torch_dtype=torch.bfloat16,
|
| 310 |
+
device="cuda",
|
| 311 |
+
model_configs=[
|
| 312 |
+
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
| 313 |
+
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
| 314 |
+
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
| 315 |
+
],
|
| 316 |
+
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
| 317 |
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
| 318 |
+
)
|
| 319 |
+
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
|
| 320 |
+
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
| 321 |
+
image.save("image.jpg")
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+
</details>
|
| 325 |
+
|
| 326 |
+
<details>
|
| 327 |
+
|
| 328 |
+
<summary>示例代码</summary>
|
| 329 |
+
|
| 330 |
+
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|
| 331 |
+
|
| 332 |
+
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|
| 333 |
+
|-|-|-|-|-|
|
| 334 |
+
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
| 335 |
+
|
| 336 |
+
</details>
|
| 337 |
+
|
| 338 |
+
#### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
|
| 339 |
+
|
| 340 |
+
<details>
|
| 341 |
+
|
| 342 |
+
<summary>快速开始</summary>
|
| 343 |
+
|
| 344 |
+
运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
| 345 |
+
|
| 346 |
+
```python
|
| 347 |
+
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
| 348 |
+
import torch
|
| 349 |
+
|
| 350 |
+
vram_config = {
|
| 351 |
+
"offload_dtype": "disk",
|
| 352 |
+
"offload_device": "disk",
|
| 353 |
+
"onload_dtype": torch.float8_e4m3fn,
|
| 354 |
+
"onload_device": "cpu",
|
| 355 |
+
"preparing_dtype": torch.float8_e4m3fn,
|
| 356 |
+
"preparing_device": "cuda",
|
| 357 |
+
"computation_dtype": torch.bfloat16,
|
| 358 |
+
"computation_device": "cuda",
|
| 359 |
+
}
|
| 360 |
+
pipe = QwenImagePipeline.from_pretrained(
|
| 361 |
+
torch_dtype=torch.bfloat16,
|
| 362 |
+
device="cuda",
|
| 363 |
+
model_configs=[
|
| 364 |
+
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
| 365 |
+
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
| 366 |
+
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
| 367 |
+
],
|
| 368 |
+
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
| 369 |
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
| 370 |
+
)
|
| 371 |
+
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
| 372 |
+
image = pipe(prompt, seed=0, num_inference_steps=40)
|
| 373 |
+
image.save("image.jpg")
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
</details>
|
| 377 |
+
|
| 378 |
+
<details>
|
| 379 |
+
|
| 380 |
+
<summary>模型血缘</summary>
|
| 381 |
+
|
| 382 |
+
```mermaid
|
| 383 |
+
graph LR;
|
| 384 |
+
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
|
| 385 |
+
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
|
| 386 |
+
Qwen/Qwen-Image-->EliGen-Series;
|
| 387 |
+
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
|
| 388 |
+
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
|
| 389 |
+
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
|
| 390 |
+
Qwen/Qwen-Image-->Distill-Series;
|
| 391 |
+
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
|
| 392 |
+
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
|
| 393 |
+
Qwen/Qwen-Image-->ControlNet-Series;
|
| 394 |
+
ControlNet-Series-->Blockwise-ControlNet-Series;
|
| 395 |
+
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
|
| 396 |
+
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
|
| 397 |
+
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
|
| 398 |
+
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
|
| 399 |
+
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
|
| 400 |
+
```
|
| 401 |
+
|
| 402 |
+
</details>
|
| 403 |
+
|
| 404 |
+
<details>
|
| 405 |
+
|
| 406 |
+
<summary>示例代码</summary>
|
| 407 |
+
|
| 408 |
+
Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/)
|
| 409 |
+
|
| 410 |
+
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
| 411 |
+
|-|-|-|-|-|-|-|
|
| 412 |
+
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
| 413 |
+
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
| 414 |
+
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
| 415 |
+
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
| 416 |
+
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
| 417 |
+
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
| 418 |
+
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
| 419 |
+
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
| 420 |
+
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
| 421 |
+
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
| 422 |
+
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
| 423 |
+
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
| 424 |
+
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
| 425 |
+
|[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
|
| 426 |
+
|
| 427 |
+
</details>
|
| 428 |
+
|
| 429 |
+
#### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)
|
| 430 |
+
|
| 431 |
+
<details>
|
| 432 |
+
|
| 433 |
+
<summary>快速开始</summary>
|
| 434 |
+
|
| 435 |
+
运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
| 436 |
+
|
| 437 |
+
```python
|
| 438 |
+
import torch
|
| 439 |
+
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
|
| 440 |
+
|
| 441 |
+
vram_config = {
|
| 442 |
+
"offload_dtype": torch.float8_e4m3fn,
|
| 443 |
+
"offload_device": "cpu",
|
| 444 |
+
"onload_dtype": torch.float8_e4m3fn,
|
| 445 |
+
"onload_device": "cpu",
|
| 446 |
+
"preparing_dtype": torch.float8_e4m3fn,
|
| 447 |
+
"preparing_device": "cuda",
|
| 448 |
+
"computation_dtype": torch.bfloat16,
|
| 449 |
+
"computation_device": "cuda",
|
| 450 |
+
}
|
| 451 |
+
pipe = FluxImagePipeline.from_pretrained(
|
| 452 |
+
torch_dtype=torch.bfloat16,
|
| 453 |
+
device="cuda",
|
| 454 |
+
model_configs=[
|
| 455 |
+
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
|
| 456 |
+
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
| 457 |
+
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
|
| 458 |
+
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
|
| 459 |
+
],
|
| 460 |
+
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
|
| 461 |
+
)
|
| 462 |
+
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
| 463 |
+
image = pipe(prompt=prompt, seed=0)
|
| 464 |
+
image.save("image.jpg")
|
| 465 |
+
```
|
| 466 |
+
|
| 467 |
+
</details>
|
| 468 |
+
|
| 469 |
+
<details>
|
| 470 |
+
|
| 471 |
+
<summary>模型血缘</summary>
|
| 472 |
+
|
| 473 |
+
```mermaid
|
| 474 |
+
graph LR;
|
| 475 |
+
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
|
| 476 |
+
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
|
| 477 |
+
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
|
| 478 |
+
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
|
| 479 |
+
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
|
| 480 |
+
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
|
| 481 |
+
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
|
| 482 |
+
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
|
| 483 |
+
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
|
| 484 |
+
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
|
| 485 |
+
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
|
| 486 |
+
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
|
| 487 |
+
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
|
| 488 |
+
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
|
| 489 |
+
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
|
| 490 |
+
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
|
| 491 |
+
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
|
| 492 |
+
```
|
| 493 |
+
|
| 494 |
+
</details>
|
| 495 |
+
|
| 496 |
+
<details>
|
| 497 |
+
|
| 498 |
+
<summary>示例代码</summary>
|
| 499 |
+
|
| 500 |
+
FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
| 501 |
+
|
| 502 |
+
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
| 503 |
+
|-|-|-|-|-|-|-|-|
|
| 504 |
+
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
| 505 |
+
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
| 506 |
+
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
| 507 |
+
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
| 508 |
+
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
| 509 |
+
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
| 510 |
+
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
| 511 |
+
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
| 512 |
+
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
| 513 |
+
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
| 514 |
+
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
| 515 |
+
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
| 516 |
+
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
| 517 |
+
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
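
The extra parameters in the table above are passed directly as keyword arguments when calling the pipeline. Below is a minimal sketch for FLUX.1-Kontext-dev, assuming the same `pipe` object as in the quick start with the Kontext weights loaded; the reference image path and the editing prompt are hypothetical:

```python
# Minimal sketch (assumption): extra parameters such as `kontext_images`
# are passed as keyword arguments of the pipeline call.
from PIL import Image

reference = Image.open("reference.jpg")  # hypothetical reference image
image = pipe(
    prompt="turn the dress into a red dress",  # hypothetical editing instruction
    kontext_images=reference,                  # extra parameter from the table above
    seed=0,
)
image.save("image_kontext.jpg")
```
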
|
| 518 |
+
|
| 519 |
+
</details>
|
| 520 |
+
|
| 521 |
+
### Video generation models
|
| 522 |
+
|
| 523 |
+
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
| 524 |
+
|
| 525 |
+
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
|
| 526 |
+
|
| 527 |
+
<details>
|
| 528 |
+
|
| 529 |
+
<summary>Quick start</summary>
|
| 530 |
+
|
| 531 |
+
Run the following code to quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model and run inference. VRAM management is enabled: the framework automatically decides how model parameters are loaded based on the remaining VRAM, so the pipeline runs on as little as 8 GB of VRAM.
|
| 532 |
+
|
| 533 |
+
```python
|
| 534 |
+
import torch
|
| 535 |
+
from diffsynth.utils.data import save_video, VideoData
|
| 536 |
+
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
| 537 |
+
|
| 538 |
+
vram_config = {
|
| 539 |
+
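# VRAM management policy (as read from the keys below): weights are offloaded to disk,
# held on the CPU in bfloat16 when loaded, and moved to the GPU in bfloat16 for computation.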
"offload_dtype": "disk",
|
| 540 |
+
"offload_device": "disk",
|
| 541 |
+
"onload_dtype": torch.bfloat16,
|
| 542 |
+
"onload_device": "cpu",
|
| 543 |
+
"preparing_dtype": torch.bfloat16,
|
| 544 |
+
"preparing_device": "cuda",
|
| 545 |
+
"computation_dtype": torch.bfloat16,
|
| 546 |
+
"computation_device": "cuda",
|
| 547 |
+
}
|
| 548 |
+
pipe = WanVideoPipeline.from_pretrained(
|
| 549 |
+
torch_dtype=torch.bfloat16,
|
| 550 |
+
device="cuda",
|
| 551 |
+
model_configs=[
|
| 552 |
+
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
|
| 553 |
+
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
|
| 554 |
+
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
|
| 555 |
+
],
|
| 556 |
+
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
| 557 |
+
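# vram_limit: total GPU memory in GiB minus a ~2 GiB safety margin
# (torch.cuda.mem_get_info returns (free_bytes, total_bytes))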
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
video = pipe(
|
| 561 |
+
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
| 562 |
+
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
| 563 |
+
seed=0, tiled=True,
|
| 564 |
+
)
|
| 565 |
+
save_video(video, "video.mp4", fps=15, quality=5)
|
| 566 |
+
```
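
If your GPU can hold all of the weights, the disk-offload configuration above is unnecessary. A minimal sketch under that assumption (not an official recipe): drop `vram_config` and `vram_limit` and load the weights directly in `torch.bfloat16` on the GPU.

```python
# Minimal sketch (assumption): load all weights directly on the GPU, no offloading.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
    ],
    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
```
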
|
| 567 |
+
|
| 568 |
+
</details>
|
| 569 |
+
|
| 570 |
+
<details>
|
| 571 |
+
|
| 572 |
+
<summary>Model lineage</summary>
|
| 573 |
+
|
| 574 |
+
```mermaid
|
| 575 |
+
graph LR;
|
| 576 |
+
Wan-Series-->Wan2.1-Series;
|
| 577 |
+
Wan-Series-->Wan2.2-Series;
|
| 578 |
+
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
|
| 579 |
+
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
|
| 580 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
|
| 581 |
+
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
|
| 582 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
|
| 583 |
+
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
|
| 584 |
+
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
|
| 585 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
|
| 586 |
+
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
|
| 587 |
+
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
|
| 588 |
+
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
|
| 589 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
|
| 590 |
+
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
|
| 591 |
+
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
|
| 592 |
+
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
|
| 593 |
+
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
|
| 594 |
+
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
|
| 595 |
+
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
|
| 596 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
|
| 597 |
+
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
|
| 598 |
+
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
|
| 599 |
+
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
|
| 600 |
+
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
|
| 601 |
+
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
|
| 602 |
+
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
|
| 603 |
+
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
|
| 604 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
|
| 605 |
+
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
|
| 606 |
+
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
|
| 607 |
+
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
|
| 608 |
+
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
|
| 609 |
+
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
|
| 610 |
+
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
|
| 611 |
+
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
|
| 612 |
+
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
|
| 613 |
+
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
|
| 614 |
+
```
|
| 615 |
+
|
| 616 |
+
</details>
|
| 617 |
+
|
| 618 |
+
<details>
|
| 619 |
+
|
| 620 |
+
<summary>Example code</summary>
|
| 621 |
+
|
| 622 |
+
Example code for Wan lives in [/examples/wanvideo/](/examples/wanvideo/). The table below lists the extra inference parameters each model accepts; a usage sketch follows the table.
|
| 623 |
+
|
| 624 |
+
|Model ID|Extra parameters|Inference|Full training|Validation after full training|LoRA training|Validation after LoRA training|
|
| 625 |
+
|-|-|-|-|-|-|-|
|
| 626 |
+
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
| 627 |
+
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
| 628 |
+
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
| 629 |
+
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
| 630 |
+
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
| 631 |
+
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
| 632 |
+
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
| 633 |
+
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
| 634 |
+
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
| 635 |
+
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
| 636 |
+
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
| 637 |
+
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
| 638 |
+
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
| 639 |
+
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
| 640 |
+
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
| 641 |
+
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
| 642 |
+
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
| 643 |
+
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
| 644 |
+
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
| 645 |
+
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
| 646 |
+
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
| 647 |
+
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
| 648 |
+
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
| 649 |
+
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
| 650 |
+
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
| 651 |
+
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
| 652 |
+
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
| 653 |
+
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|
| 654 |
+
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
| 655 |
+
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
| 656 |
+
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
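
The extra parameters in the table above are passed directly as keyword arguments when calling the pipeline. Below is a minimal sketch for an image-to-video model, assuming a `pipe` built as in the quick start but with the Wan2.1-I2V-14B-480P files loaded; the input image path and the prompt are hypothetical:

```python
# Minimal sketch (assumption): extra parameters such as `input_image`
# are passed as keyword arguments of the pipeline call.
from PIL import Image

first_frame = Image.open("first_frame.jpg")  # hypothetical input image
video = pipe(
    prompt="a corgi running on the beach",  # hypothetical prompt
    input_image=first_frame,                # extra parameter from the table above
    seed=0, tiled=True,
)
save_video(video, "video_i2v.mp4", fps=15, quality=5)
```
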
|
| 657 |
+
|
| 658 |
+
</details>
|
| 659 |
+
|
| 660 |
+
## Innovations
|
| 661 |
+
|
| 662 |
+
DiffSynth-Studio is more than an engineering-oriented model framework; it is also an incubator for research innovations.
|
| 663 |
+
|
| 664 |
+
<details>
|
| 665 |
+
|
| 666 |
+
<summary>AttriCtrl: attribute-intensity control for image generation models</summary>
|
| 667 |
+
|
| 668 |
+
- Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
|
| 672 |
+
|
| 673 |
+
|brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
|
| 674 |
+
|-|-|-|-|-|
|
| 675 |
+
||||||
|
| 676 |
+
|
| 677 |
+
</details>
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
<details>
|
| 681 |
+
|
| 682 |
+
<summary>AutoLoRA: automated LoRA retrieval and fusion</summary>
|
| 683 |
+
|
| 684 |
+
- Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
|
| 688 |
+
|
| 689 |
+
||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
|
| 690 |
+
|-|-|-|-|-|
|
| 691 |
+
|[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |||||
|
| 692 |
+
|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |||||
|
| 693 |
+
|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |||||
|
| 694 |
+
|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |||||
|
| 695 |
+
|
| 696 |
+
</details>
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
<details>
|
| 700 |
+
|
| 701 |
+
<summary>Nexus-Gen: unified image understanding, generation, and editing</summary>
|
| 702 |
+
|
| 703 |
+
- Details: https://github.com/modelscope/Nexus-Gen
- Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
- Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
|
| 708 |
+
|
| 709 |
+

|
| 710 |
+
|
| 711 |
+
</details>
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
<details>
|
| 715 |
+
|
| 716 |
+
<summary>ArtAug: aesthetic enhancement for image generation models</summary>
|
| 717 |
+
|
| 718 |
+
- Details: [./examples/ArtAug/](./examples/ArtAug/)
- Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
- Models: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
- Online demo: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
|
| 722 |
+
|
| 723 |
+
|FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
|
| 724 |
+
|-|-|
|
| 725 |
+
|||
|
| 726 |
+
|
| 727 |
+
</details>
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
<details>
|
| 731 |
+
|
| 732 |
+
<summary>EliGen: precise entity-level regional control for image generation</summary>
|
| 733 |
+
|
| 734 |
+
- Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
- Example code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
- Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
- Online demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
- Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
|
| 739 |
+
|
| 740 |
+
|Entity control regions|Generated image|
|
| 741 |
+
|-|-|
|
| 742 |
+
|||
|
| 743 |
+
|
| 744 |
+
</details>
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
<details>
|
| 748 |
+
|
| 749 |
+
<summary>ExVideo: extended post-training for video generation models</summary>
|
| 750 |
+
|
| 751 |
+
- Project page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
- Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
- Models: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
| 755 |
+
|
| 756 |
+
https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
|
| 757 |
+
|
| 758 |
+
</details>
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
<details>
|
| 762 |
+
|
| 763 |
+
<summary>Diffutoon: high-resolution anime-style video rendering</summary>
|
| 764 |
+
|
| 765 |
+
- Project page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
- Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)
|
| 768 |
+
|
| 769 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
|
| 770 |
+
|
| 771 |
+
</details>
|
| 772 |
+
|
| 773 |
+
|
| 774 |
+
<details>
|
| 775 |
+
|
| 776 |
+
<summary>DiffSynth: the original version of this project</summary>
|
| 777 |
+
|
| 778 |
+
- Project page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
- Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
- Example code: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
|
| 781 |
+
|
| 782 |
+
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
| 783 |
+
|
| 784 |
+
</details>
|
assets/egg.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:96647e399f3a7772d32bee1562a0c01fe8b273f6ad5fe2708cc10878972fce04
|
| 3 |
+
size 4810228
|
comp_attn_bbox_layout.png
ADDED
|
comp_attn_trajectory.png
ADDED
|
diffsynth/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
from .core import *
|
diffsynth/configs/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
| 1 |
+
from .model_configs import MODEL_CONFIGS
|
| 2 |
+
from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
|
diffsynth/configs/model_configs.py
ADDED
|
@@ -0,0 +1,518 @@
|
| 1 |
+
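# Each entry below describes one loadable checkpoint: `model_hash` identifies the
# checkpoint file, `model_name` is the name the loaded module is exposed under,
# `model_class` is the import path of the class to instantiate, and the optional
# `state_dict_converter` / `extra_kwargs` fields customize weight conversion and
# construction.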
qwen_image_series = [
|
| 2 |
+
{
|
| 3 |
+
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
|
| 4 |
+
"model_hash": "0319a1cb19835fb510907dd3367c95ff",
|
| 5 |
+
"model_name": "qwen_image_dit",
|
| 6 |
+
"model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
| 10 |
+
"model_hash": "8004730443f55db63092006dd9f7110e",
|
| 11 |
+
"model_name": "qwen_image_text_encoder",
|
| 12 |
+
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
| 13 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
# Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
| 17 |
+
"model_hash": "ed4ea5824d55ec3107b09815e318123a",
|
| 18 |
+
"model_name": "qwen_image_vae",
|
| 19 |
+
"model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
|
| 23 |
+
"model_hash": "073bce9cf969e317e5662cd570c3e79c",
|
| 24 |
+
"model_name": "qwen_image_blockwise_controlnet",
|
| 25 |
+
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
|
| 29 |
+
"model_hash": "a9e54e480a628f0b956a688a81c33bab",
|
| 30 |
+
"model_name": "qwen_image_blockwise_controlnet",
|
| 31 |
+
"model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
|
| 32 |
+
"extra_kwargs": {"additional_in_dim": 4},
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
|
| 36 |
+
"model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
|
| 37 |
+
"model_name": "siglip2_image_encoder",
|
| 38 |
+
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
|
| 42 |
+
"model_hash": "5722b5c873720009de96422993b15682",
|
| 43 |
+
"model_name": "dinov3_image_encoder",
|
| 44 |
+
"model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
# Example:
|
| 48 |
+
"model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
|
| 49 |
+
"model_name": "qwen_image_image2lora_coarse",
|
| 50 |
+
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
# Example:
|
| 54 |
+
"model_hash": "a5476e691767a4da6d3a6634a10f7408",
|
| 55 |
+
"model_name": "qwen_image_image2lora_fine",
|
| 56 |
+
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
| 57 |
+
"extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
|
| 58 |
+
},
|
| 59 |
+
{
|
| 60 |
+
# Example:
|
| 61 |
+
"model_hash": "0aad514690602ecaff932c701cb4b0bb",
|
| 62 |
+
"model_name": "qwen_image_image2lora_style",
|
| 63 |
+
"model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
|
| 64 |
+
"extra_kwargs": {"compress_dim": 64, "use_residual": False}
|
| 65 |
+
},
|
| 66 |
+
]
|
| 67 |
+
|
| 68 |
+
wan_series = [
|
| 69 |
+
{
|
| 70 |
+
# Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
|
| 71 |
+
"model_hash": "5ec04e02b42d2580483ad69f4e76346a",
|
| 72 |
+
"model_name": "wan_video_dit",
|
| 73 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 74 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
| 75 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
|
| 79 |
+
"model_hash": "9c8818c2cbea55eca56c7b447df170da",
|
| 80 |
+
"model_name": "wan_video_text_encoder",
|
| 81 |
+
"model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
|
| 85 |
+
"model_hash": "ccc42284ea13e1ad04693284c7a09be6",
|
| 86 |
+
"model_name": "wan_video_vae",
|
| 87 |
+
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
|
| 88 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
# Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
|
| 92 |
+
"model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
|
| 93 |
+
"model_name": "wan_video_dit",
|
| 94 |
+
"model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
| 98 |
+
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
| 99 |
+
"model_name": "wan_video_dit",
|
| 100 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 101 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
| 102 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
# Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
|
| 106 |
+
"model_hash": "5f90e66a0672219f12d9a626c8c21f61",
|
| 107 |
+
"model_name": "wan_video_vap",
|
| 108 |
+
"model_class": "diffsynth.models.wan_video_mot.MotWanModel",
|
| 109 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
|
| 110 |
+
},
|
| 111 |
+
{
|
| 112 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
|
| 113 |
+
"model_hash": "5941c53e207d62f20f9025686193c40b",
|
| 114 |
+
"model_name": "wan_video_image_encoder",
|
| 115 |
+
"model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
|
| 116 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
|
| 120 |
+
"model_hash": "dbd5ec76bbf977983f972c151d545389",
|
| 121 |
+
"model_name": "wan_video_motion_controller",
|
| 122 |
+
"model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 126 |
+
"model_hash": "9269f8db9040a9d860eaca435be61814",
|
| 127 |
+
"model_name": "wan_video_dit",
|
| 128 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 129 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 133 |
+
"model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
|
| 134 |
+
"model_name": "wan_video_dit",
|
| 135 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 136 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 140 |
+
"model_hash": "349723183fc063b2bfc10bb2835cf677",
|
| 141 |
+
"model_name": "wan_video_dit",
|
| 142 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 143 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 147 |
+
"model_hash": "6d6ccde6845b95ad9114ab993d917893",
|
| 148 |
+
"model_name": "wan_video_dit",
|
| 149 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 150 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 154 |
+
"model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
|
| 155 |
+
"model_name": "wan_video_dit",
|
| 156 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 157 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 161 |
+
"model_hash": "6bfcfb3b342cb286ce886889d519a77e",
|
| 162 |
+
"model_name": "wan_video_dit",
|
| 163 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 164 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 168 |
+
"model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
|
| 169 |
+
"model_name": "wan_video_dit",
|
| 170 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 171 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 175 |
+
"model_hash": "70ddad9d3a133785da5ea371aae09504",
|
| 176 |
+
"model_name": "wan_video_dit",
|
| 177 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 178 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 182 |
+
"model_hash": "b61c605c2adbd23124d152ed28e049ae",
|
| 183 |
+
"model_name": "wan_video_dit",
|
| 184 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 185 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
# Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 189 |
+
"model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
|
| 190 |
+
"model_name": "wan_video_dit",
|
| 191 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 192 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 196 |
+
"model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
|
| 197 |
+
"model_name": "wan_video_dit",
|
| 198 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 199 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
|
| 200 |
+
},
|
| 201 |
+
{
|
| 202 |
+
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 203 |
+
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
| 204 |
+
"model_name": "wan_video_dit",
|
| 205 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 206 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
|
| 207 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
| 208 |
+
},
|
| 209 |
+
{
|
| 210 |
+
# Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 211 |
+
"model_hash": "a61453409b67cd3246cf0c3bebad47ba",
|
| 212 |
+
"model_name": "wan_video_vace",
|
| 213 |
+
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
| 214 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 218 |
+
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
| 219 |
+
"model_name": "wan_video_dit",
|
| 220 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 221 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
| 222 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
|
| 223 |
+
},
|
| 224 |
+
{
|
| 225 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 226 |
+
"model_hash": "7a513e1f257a861512b1afd387a8ecd9",
|
| 227 |
+
"model_name": "wan_video_vace",
|
| 228 |
+
"model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
|
| 229 |
+
"extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
|
| 230 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 234 |
+
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
| 235 |
+
"model_name": "wan_video_dit",
|
| 236 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 237 |
+
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
|
| 238 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 242 |
+
"model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
|
| 243 |
+
"model_name": "wan_video_animate_adapter",
|
| 244 |
+
"model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
|
| 245 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
| 249 |
+
"model_hash": "47dbeab5e560db3180adf51dc0232fb1",
|
| 250 |
+
"model_name": "wan_video_dit",
|
| 251 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 252 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
# Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
| 256 |
+
"model_hash": "2267d489f0ceb9f21836532952852ee5",
|
| 257 |
+
"model_name": "wan_video_dit",
|
| 258 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 259 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
|
| 263 |
+
"model_hash": "5b013604280dd715f8457c6ed6d6a626",
|
| 264 |
+
"model_name": "wan_video_dit",
|
| 265 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 266 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 270 |
+
"model_hash": "966cffdcc52f9c46c391768b27637614",
|
| 271 |
+
"model_name": "wan_video_dit",
|
| 272 |
+
"model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
|
| 273 |
+
"extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
|
| 274 |
+
},
|
| 275 |
+
{
|
| 276 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
| 277 |
+
"model_hash": "1f5ab7703c6fc803fdded85ff040c316",
|
| 278 |
+
"model_name": "wan_video_dit",
|
| 279 |
+
"model_class": "diffsynth.models.wan_video_dit.WanModel",
|
| 280 |
+
"extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
|
| 281 |
+
},
|
| 282 |
+
{
|
| 283 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
|
| 284 |
+
"model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
|
| 285 |
+
"model_name": "wan_video_vae",
|
| 286 |
+
"model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
|
| 287 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
# Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
|
| 291 |
+
"model_hash": "06be60f3a4526586d8431cd038a71486",
|
| 292 |
+
"model_name": "wans2v_audio_encoder",
|
| 293 |
+
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
|
| 294 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
|
| 295 |
+
},
|
| 296 |
+
]
|
| 297 |
+
|
| 298 |
+
flux_series = [
|
| 299 |
+
{
|
| 300 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
|
| 301 |
+
"model_hash": "a29710fea6dddb0314663ee823598e50",
|
| 302 |
+
"model_name": "flux_dit",
|
| 303 |
+
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
| 304 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
|
| 308 |
+
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
| 309 |
+
"model_name": "flux_text_encoder_clip",
|
| 310 |
+
"model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
|
| 311 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
|
| 315 |
+
"model_hash": "22540b49eaedbc2f2784b2091a234c7c",
|
| 316 |
+
"model_name": "flux_text_encoder_t5",
|
| 317 |
+
"model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
|
| 318 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
|
| 322 |
+
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
|
| 323 |
+
"model_name": "flux_vae_encoder",
|
| 324 |
+
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
| 325 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
|
| 329 |
+
"model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
|
| 330 |
+
"model_name": "flux_vae_decoder",
|
| 331 |
+
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
| 332 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
# Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
|
| 336 |
+
"model_hash": "d02f41c13549fa5093d3521f62a5570a",
|
| 337 |
+
"model_name": "flux_dit",
|
| 338 |
+
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
| 339 |
+
"extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
|
| 340 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
| 341 |
+
},
|
| 342 |
+
{
|
| 343 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
|
| 344 |
+
"model_hash": "0629116fce1472503a66992f96f3eb1a",
|
| 345 |
+
"model_name": "flux_value_controller",
|
| 346 |
+
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
| 350 |
+
"model_hash": "52357cb26250681367488a8954c271e8",
|
| 351 |
+
"model_name": "flux_controlnet",
|
| 352 |
+
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
| 353 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
| 354 |
+
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
| 358 |
+
"model_hash": "78d18b9101345ff695f312e7e62538c0",
|
| 359 |
+
"model_name": "flux_controlnet",
|
| 360 |
+
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
| 361 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
| 362 |
+
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
| 366 |
+
"model_hash": "b001c89139b5f053c715fe772362dd2a",
|
| 367 |
+
"model_name": "flux_controlnet",
|
| 368 |
+
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
| 369 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
| 370 |
+
"extra_kwargs": {"num_single_blocks": 0},
|
| 371 |
+
},
|
| 372 |
+
{
|
| 373 |
+
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
|
| 374 |
+
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
|
| 375 |
+
"model_name": "infiniteyou_image_projector",
|
| 376 |
+
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
|
| 377 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
|
| 381 |
+
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
|
| 382 |
+
"model_name": "flux_controlnet",
|
| 383 |
+
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
| 384 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
| 385 |
+
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
|
| 389 |
+
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
|
| 390 |
+
"model_name": "flux_lora_encoder",
|
| 391 |
+
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
|
| 392 |
+
},
|
| 393 |
+
{
|
| 394 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
|
| 395 |
+
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
|
| 396 |
+
"model_name": "flux_lora_patcher",
|
| 397 |
+
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
|
| 398 |
+
},
|
| 399 |
+
{
|
| 400 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
|
| 401 |
+
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
|
| 402 |
+
"model_name": "nexus_gen_llm",
|
| 403 |
+
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
|
| 404 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
|
| 405 |
+
},
|
| 406 |
+
{
|
| 407 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
| 408 |
+
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
| 409 |
+
"model_name": "nexus_gen_editing_adapter",
|
| 410 |
+
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
|
| 411 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
|
| 412 |
+
},
|
| 413 |
+
{
|
| 414 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
| 415 |
+
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
| 416 |
+
"model_name": "flux_dit",
|
| 417 |
+
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
| 418 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
| 419 |
+
},
|
| 420 |
+
{
|
| 421 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
| 422 |
+
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
| 423 |
+
"model_name": "nexus_gen_generation_adapter",
|
| 424 |
+
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
|
| 425 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
| 429 |
+
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
| 430 |
+
"model_name": "flux_dit",
|
| 431 |
+
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
| 432 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
| 433 |
+
},
|
| 434 |
+
{
|
| 435 |
+
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
|
| 436 |
+
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
|
| 437 |
+
"model_name": "flux_ipadapter",
|
| 438 |
+
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
|
| 439 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
|
| 443 |
+
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
|
| 444 |
+
"model_name": "siglip_vision_model",
|
| 445 |
+
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
|
| 446 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
|
| 447 |
+
},
|
| 448 |
+
{
|
| 449 |
+
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
| 450 |
+
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
|
| 451 |
+
"model_name": "step1x_connector",
|
| 452 |
+
"model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
|
| 453 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
|
| 454 |
+
},
|
| 455 |
+
{
|
| 456 |
+
# Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
|
| 457 |
+
"model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
|
| 458 |
+
"model_name": "flux_dit",
|
| 459 |
+
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
| 460 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
| 461 |
+
"extra_kwargs": {"disable_guidance_embedder": True},
|
| 462 |
+
},
|
| 463 |
+
]
|
| 464 |
+
|
| 465 |
+
flux2_series = [
|
| 466 |
+
{
|
| 467 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
|
| 468 |
+
"model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
|
| 469 |
+
"model_name": "flux2_text_encoder",
|
| 470 |
+
"model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
|
| 471 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
|
| 472 |
+
},
|
| 473 |
+
{
|
| 474 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
|
| 475 |
+
"model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
|
| 476 |
+
"model_name": "flux2_dit",
|
| 477 |
+
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
| 478 |
+
},
|
| 479 |
+
{
|
| 480 |
+
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
|
| 481 |
+
"model_hash": "c54288e3ee12ca215898840682337b95",
|
| 482 |
+
"model_name": "flux2_vae",
|
| 483 |
+
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
|
| 484 |
+
},
|
| 485 |
+
]
|
| 486 |
+
|
| 487 |
+
z_image_series = [
|
| 488 |
+
{
|
| 489 |
+
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
|
| 490 |
+
"model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
|
| 491 |
+
"model_name": "z_image_dit",
|
| 492 |
+
"model_class": "diffsynth.models.z_image_dit.ZImageDiT",
|
| 493 |
+
},
|
| 494 |
+
{
|
| 495 |
+
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
|
| 496 |
+
"model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
|
| 497 |
+
"model_name": "z_image_text_encoder",
|
| 498 |
+
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
| 499 |
+
},
|
| 500 |
+
{
|
| 501 |
+
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
| 502 |
+
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
| 503 |
+
"model_name": "flux_vae_encoder",
|
| 504 |
+
"model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
|
| 505 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
|
| 506 |
+
"extra_kwargs": {"use_conv_attention": False},
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
# Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
|
| 510 |
+
"model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
|
| 511 |
+
"model_name": "flux_vae_decoder",
|
| 512 |
+
"model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
|
| 513 |
+
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
|
| 514 |
+
"extra_kwargs": {"use_conv_attention": False},
|
| 515 |
+
},
|
| 516 |
+
]
|
| 517 |
+
|
| 518 |
+
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
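Note: MODEL_CONFIGS above is keyed by "model_hash", so a loader can select the right entry for a checkpoint by comparing hashes. The following is a minimal illustrative sketch, not part of the file above; the helper name find_model_config is hypothetical.

def find_model_config(model_hash):
    # Return the first registered config whose hash matches, else None.
    for config in MODEL_CONFIGS:
        if config["model_hash"] == model_hash:
            return config
    return None

config = find_model_config("a29710fea6dddb0314663ee823598e50")  # FLUX.1-dev DiT
# config["model_class"] -> "diffsynth.models.flux_dit.FluxDiT"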
diffsynth/configs/vram_management_module_maps.py
ADDED
@@ -0,0 +1,197 @@
flux_general_vram_config = {
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
    "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
}

VRAM_MANAGEMENT_MODULE_MAPS = {
    "diffsynth.models.qwen_image_dit.QwenImageDiT": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_vae.QwenImageVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
        "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
        "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_dit.WanModel": {
        "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
        "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
        "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_mot.MotWanModel": {
        "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
    },
    "diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vace.VaceWanModel": {
        "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wan_video_vae.WanVideoVAE38": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.wav2vec.WanS2VAudioEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_dit.FluxDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
    "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
    "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
    "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
    "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
    "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
    "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
    "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
    "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
        "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_dit.Flux2DiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
        "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux2_vae.Flux2VAE": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
        "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.z_image_dit.ZImageDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}
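Note: each entry in VRAM_MANAGEMENT_MODULE_MAPS maps a model class to the wrapper class used for its submodules, with both sides written as dotted import paths. The actual wrapping logic lives in diffsynth.core.vram and is not shown in this diff; the sketch below only illustrates, under that assumption, how such strings could be resolved to classes with importlib (resolve_class is a hypothetical helper, not repository API).

import importlib

def resolve_class(dotted_path: str):
    # "pkg.module.ClassName" -> the ClassName object
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)

module_map = VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.flux2_dit.Flux2DiT"]
resolved = {resolve_class(src): resolve_class(dst) for src, dst in module_map.items()}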
diffsynth/core/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .attention import *
from .data import *
from .gradient import *
from .loader import *
from .vram import *
diffsynth/core/attention/__init__.py
ADDED
@@ -0,0 +1 @@
from .attention import attention_forward
diffsynth/core/attention/attention.py
ADDED
@@ -0,0 +1,121 @@
import torch, os
from einops import rearrange


try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

try:
    import xformers.ops as xops
    XFORMERS_AVAILABLE = True
except ModuleNotFoundError:
    XFORMERS_AVAILABLE = False


def initialize_attention_priority():
    if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
        return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
    elif FLASH_ATTN_3_AVAILABLE:
        return "flash_attention_3"
    elif FLASH_ATTN_2_AVAILABLE:
        return "flash_attention_2"
    elif SAGE_ATTN_AVAILABLE:
        return "sage_attention"
    elif XFORMERS_AVAILABLE:
        return "xformers"
    else:
        return "torch"


ATTENTION_IMPLEMENTATION = initialize_attention_priority()


def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if q_pattern != required_in_pattern:
        q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
    if k_pattern != required_in_pattern:
        k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
    if v_pattern != required_in_pattern:
        # Bug fix: rearrange v with v_pattern (the original used q_pattern here).
        v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
    return q, k, v


def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
    dims = {} if dims is None else dims
    if out_pattern != required_out_pattern:
        out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
    return out


def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
    if isinstance(out, tuple):
        out = out[0]
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b n s d", "b n s d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = sageattn(q, k, v, sm_scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
    required_in_pattern, required_out_pattern = "b s n d", "b s n d"
    q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
    out = xops.memory_efficient_attention(q, k, v, scale=scale)
    out = rearrange_out(out, out_pattern, required_out_pattern, dims)
    return out


def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
    if compatibility_mode or (attn_mask is not None):
        return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
    else:
        if ATTENTION_IMPLEMENTATION == "flash_attention_3":
            return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
            return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "sage_attention":
            return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        elif ATTENTION_IMPLEMENTATION == "xformers":
            return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
        else:
            return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
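Note: attention_forward accepts q/k/v in any einops layout described by the *_pattern arguments and converts them to whatever the selected backend needs; the backend can be forced via the DIFFSYNTH_ATTENTION_IMPLEMENTATION environment variable. A small usage sketch (shapes chosen purely for illustration; with the default torch backend this runs on CPU, while the accelerated backends generally expect half-precision CUDA tensors):

import torch
from diffsynth.core.attention import attention_forward

# 2 sequences, 8 heads of size 64, 128 tokens, laid out as (batch, seq, heads*dim)
q = torch.randn(2, 128, 8 * 64)
k = torch.randn(2, 128, 8 * 64)
v = torch.randn(2, 128, 8 * 64)
out = attention_forward(
    q, k, v,
    q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)",
    out_pattern="b s (n d)", dims={"n": 8},
)  # out has shape (2, 128, 512)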
diffsynth/core/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .unified_dataset import UnifiedDataset
diffsynth/core/data/operators.py
ADDED
@@ -0,0 +1,218 @@
import torch, torchvision, imageio, os
import imageio.v3 as iio
from PIL import Image


class DataProcessingPipeline:
    def __init__(self, operators=None):
        self.operators: list[DataProcessingOperator] = [] if operators is None else operators

    def __call__(self, data):
        for operator in self.operators:
            data = operator(data)
        return data

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline(self.operators + pipe.operators)


class DataProcessingOperator:
    def __call__(self, data):
        raise NotImplementedError("DataProcessingOperator cannot be called directly.")

    def __rshift__(self, pipe):
        if isinstance(pipe, DataProcessingOperator):
            pipe = DataProcessingPipeline([pipe])
        return DataProcessingPipeline([self]).__rshift__(pipe)


class DataProcessingOperatorRaw(DataProcessingOperator):
    def __call__(self, data):
        return data


class ToInt(DataProcessingOperator):
    def __call__(self, data):
        return int(data)


class ToFloat(DataProcessingOperator):
    def __call__(self, data):
        return float(data)


class ToStr(DataProcessingOperator):
    def __init__(self, none_value=""):
        self.none_value = none_value

    def __call__(self, data):
        if data is None: data = self.none_value
        return str(data)


class LoadImage(DataProcessingOperator):
    def __init__(self, convert_RGB=True):
        self.convert_RGB = convert_RGB

    def __call__(self, data: str):
        image = Image.open(data)
        if self.convert_RGB: image = image.convert("RGB")
        return image


class ImageCropAndResize(DataProcessingOperator):
    def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
        self.height = height
        self.width = width
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor

    def crop_and_resize(self, image, target_height, target_width):
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        if self.height is None or self.width is None:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    def __call__(self, data: Image.Image):
        image = self.crop_and_resize(data, *self.get_height_width(data))
        return image


class ToList(DataProcessingOperator):
    def __call__(self, data):
        return [data]


class LoadVideo(DataProcessingOperator):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the video loader for efficiency.
        self.frame_processor = frame_processor

    def get_num_frames(self, reader):
        num_frames = self.num_frames
        if int(reader.count_frames()) < num_frames:
            num_frames = int(reader.count_frames())
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        reader = imageio.get_reader(data)
        num_frames = self.get_num_frames(reader)
        frames = []
        for frame_id in range(num_frames):
            frame = reader.get_data(frame_id)
            frame = Image.fromarray(frame)
            frame = self.frame_processor(frame)
            frames.append(frame)
        reader.close()
        return frames


class SequencialProcess(DataProcessingOperator):
    def __init__(self, operator=lambda x: x):
        self.operator = operator

    def __call__(self, data):
        return [self.operator(i) for i in data]


class LoadGIF(DataProcessingOperator):
    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # frame_processor is built into the GIF loader for efficiency.
        self.frame_processor = frame_processor

    def get_num_frames(self, path):
        num_frames = self.num_frames
        images = iio.imread(path, mode="RGB")
        if len(images) < num_frames:
            num_frames = len(images)
            while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        num_frames = self.get_num_frames(data)
        frames = []
        images = iio.imread(data, mode="RGB")
        for img in images:
            frame = Image.fromarray(img)
            frame = self.frame_processor(frame)
            frames.append(frame)
            if len(frames) >= num_frames:
                break
        return frames


class RouteByExtensionName(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        file_ext_name = data.split(".")[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is None or file_ext_name in ext_names:
                return operator(data)
        raise ValueError(f"Unsupported file: {data}")


class RouteByType(DataProcessingOperator):
    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is None or isinstance(data, dtype):
                return operator(data)
        raise ValueError(f"Unsupported data: {data}")


class LoadTorchPickle(DataProcessingOperator):
    def __init__(self, map_location="cpu"):
        self.map_location = map_location

    def __call__(self, data):
        return torch.load(data, map_location=self.map_location, weights_only=False)


class ToAbsolutePath(DataProcessingOperator):
    def __init__(self, base_path=""):
        self.base_path = base_path

    def __call__(self, data):
        return os.path.join(self.base_path, data)


class LoadAudio(DataProcessingOperator):
    def __init__(self, sr=16000):
        self.sr = sr

    def __call__(self, data: str):
        import librosa
        input_audio, sample_rate = librosa.load(data, sr=self.sr)
        return input_audio
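Note: operators compose with the `>>` operator into a DataProcessingPipeline, so a loading recipe reads left to right. A short usage sketch based on the classes above (paths are placeholders):

pipeline = ToAbsolutePath("/data/images") >> LoadImage() >> ImageCropAndResize(height=512, width=512)
image = pipeline("cat.jpg")  # -> a 512x512 PIL image, cropped and resized from /data/images/cat.jpg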
diffsynth/core/data/unified_dataset.py
ADDED
@@ -0,0 +1,112 @@
from .operators import *
import torch, json, pandas


class UnifiedDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
    ):
        self.base_path = base_path
        self.metadata_path = metadata_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.data = []
        self.cached_data = []
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
            (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
    ):
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        for file_name in os.listdir(path):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
            self.data = metadata
        elif metadata_path.endswith(".jsonl"):
            metadata = []
            with open(metadata_path, 'r') as f:
                for line in f:
                    metadata.append(json.loads(line.strip()))
            self.data = metadata
        else:
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        if self.load_from_cache:
            data = self.cached_data[data_id % len(self.cached_data)]
            data = self.cached_data_operator(data)
        else:
            data = self.data[data_id % len(self.data)].copy()
            for key in self.data_file_keys:
                if key in data:
                    if key in self.special_operator_map:
                        data[key] = self.special_operator_map[key](data[key])
                    elif key in self.data_file_keys:
                        data[key] = self.main_data_operator(data[key])
        return data

    def __len__(self):
        if self.load_from_cache:
            return len(self.cached_data) * self.repeat
        else:
            return len(self.data) * self.repeat

    def check_data_equal(self, data1, data2):
        # Debug only
        if len(data1) != len(data2):
            return False
        for k in data1:
            if data1[k] != data2[k]:
                return False
        return True
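Note: a typical instantiation pairs a metadata file with one of the default operators. The sketch below is illustrative only; the paths and the field names ("image", "prompt") are assumptions that depend on the metadata schema, not values taken from this diff.

dataset = UnifiedDataset(
    base_path="data/example_dataset",
    metadata_path="data/example_dataset/metadata.csv",
    data_file_keys=("image",),
    main_data_operator=UnifiedDataset.default_image_operator(
        base_path="data/example_dataset", height=1024, width=1024,
    ),
)
sample = dataset[0]  # e.g. {"image": <PIL.Image>, "prompt": "..."}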
diffsynth/core/gradient/__init__.py
ADDED
@@ -0,0 +1 @@
from .gradient_checkpoint import gradient_checkpoint_forward
diffsynth/core/gradient/gradient_checkpoint.py
ADDED
@@ -0,0 +1,34 @@
import torch


def create_custom_forward(module):
    def custom_forward(*inputs, **kwargs):
        return module(*inputs, **kwargs)
    return custom_forward


def gradient_checkpoint_forward(
    model,
    use_gradient_checkpointing,
    use_gradient_checkpointing_offload,
    *args,
    **kwargs,
):
    if use_gradient_checkpointing_offload:
        with torch.autograd.graph.save_on_cpu():
            model_output = torch.utils.checkpoint.checkpoint(
                create_custom_forward(model),
                *args,
                **kwargs,
                use_reentrant=False,
            )
    elif use_gradient_checkpointing:
        model_output = torch.utils.checkpoint.checkpoint(
            create_custom_forward(model),
            *args,
            **kwargs,
            use_reentrant=False,
        )
    else:
        model_output = model(*args, **kwargs)
    return model_output
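A minimal usage sketch of `gradient_checkpoint_forward` (illustrative only, not part of the upload); the wrapped block is an arbitrary `torch.nn.Sequential`.

import torch
from diffsynth.core.gradient import gradient_checkpoint_forward

block = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU())
x = torch.randn(2, 64, requires_grad=True)

# Arguments: model, use_gradient_checkpointing, use_gradient_checkpointing_offload, then *args.
# With checkpointing enabled, activations are recomputed during backward instead of being stored.
y = gradient_checkpoint_forward(block, True, False, x)
y.sum().backward()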
diffsynth/core/loader/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .file import load_state_dict, hash_state_dict_keys, hash_model_file
from .model import load_model, load_model_with_disk_offload
from .config import ModelConfig
diffsynth/core/loader/config.py
ADDED
@@ -0,0 +1,117 @@
import torch, glob, os
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download


@dataclass
class ModelConfig:
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    download_source: str = None
    local_model_path: str = None
    skip_download: bool = None
    offload_device: Optional[Union[str, torch.device]] = None
    offload_dtype: Optional[torch.dtype] = None
    onload_device: Optional[Union[str, torch.device]] = None
    onload_dtype: Optional[torch.dtype] = None
    preparing_device: Optional[Union[str, torch.device]] = None
    preparing_dtype: Optional[torch.dtype] = None
    computation_device: Optional[Union[str, torch.device]] = None
    computation_dtype: Optional[torch.dtype] = None
    clear_parameters: bool = False

    def check_input(self):
        if self.path is None and self.model_id is None:
            raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")

    def parse_original_file_pattern(self):
        if self.origin_file_pattern is None or self.origin_file_pattern == "":
            return "*"
        elif self.origin_file_pattern.endswith("/"):
            return self.origin_file_pattern + "*"
        else:
            return self.origin_file_pattern

    def parse_download_source(self):
        if self.download_source is None:
            if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
                return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
            else:
                return "modelscope"
        else:
            return self.download_source

    def parse_skip_download(self):
        if self.skip_download is None:
            if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
                if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
                    return True
                elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
                    return False
            else:
                return False
        else:
            return self.skip_download

    def download(self):
        origin_file_pattern = self.parse_original_file_pattern()
        downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
        download_source = self.parse_download_source()
        if download_source.lower() == "modelscope":
            snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_file_pattern=origin_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )
        elif download_source.lower() == "huggingface":
            hf_snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_patterns=origin_file_pattern,
                ignore_patterns=downloaded_files,
                local_files_only=False
            )
        else:
            raise ValueError("`download_source` should be `modelscope` or `huggingface`.")

    def require_downloading(self):
        if self.path is not None:
            return False
        skip_download = self.parse_skip_download()
        return not skip_download

    def reset_local_model_path(self):
        if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
            self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
        elif self.local_model_path is None:
            self.local_model_path = "./models"

    def download_if_necessary(self):
        self.check_input()
        self.reset_local_model_path()
        if self.require_downloading():
            self.download()
            if self.origin_file_pattern is None or self.origin_file_pattern == "":
                self.path = os.path.join(self.local_model_path, self.model_id)
            else:
                self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
                if isinstance(self.path, list) and len(self.path) == 1:
                    self.path = self.path[0]

    def vram_config(self):
        return {
            "offload_device": self.offload_device,
            "offload_dtype": self.offload_dtype,
            "onload_device": self.onload_device,
            "onload_dtype": self.onload_dtype,
            "preparing_device": self.preparing_device,
            "preparing_dtype": self.preparing_dtype,
            "computation_device": self.computation_device,
            "computation_dtype": self.computation_dtype,
        }
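As an illustration of how `ModelConfig` resolves files (not part of the upload; the model id and file pattern below are placeholders):

import torch
from diffsynth.core.loader import ModelConfig

config = ModelConfig(
    model_id="SomeOrg/SomeModel",          # hypothetical repository id
    origin_file_pattern="*.safetensors",   # hypothetical file pattern
    computation_device="cuda", computation_dtype=torch.bfloat16,
)
config.download_if_necessary()   # ModelScope by default; Hugging Face via DIFFSYNTH_DOWNLOAD_SOURCE
print(config.path)               # resolved local path(s)
print(config.vram_config())      # per-stage device/dtype dictionary consumed by the VRAM-management layers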
diffsynth/core/loader/file.py
ADDED
@@ -0,0 +1,121 @@
from safetensors import safe_open
import torch, hashlib


def load_state_dict(file_path, torch_dtype=None, device="cpu"):
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_state_dict(file_path_, torch_dtype, device))
        return state_dict
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
    else:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)


def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
    state_dict = {}
    with safe_open(file_path, framework="pt", device=str(device)) as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict


def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
    state_dict = torch.load(file_path, map_location=device, weights_only=True)
    if len(state_dict) == 1:
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        elif "module" in state_dict:
            state_dict = state_dict["module"]
        elif "model_state" in state_dict:
            state_dict = state_dict["model_state"]
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict


def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, torch.Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()


def load_keys_dict(file_path):
    if isinstance(file_path, list):
        state_dict = {}
        for file_path_ in file_path:
            state_dict.update(load_keys_dict(file_path_))
        return state_dict
    if file_path.endswith(".safetensors"):
        return load_keys_dict_from_safetensors(file_path)
    else:
        return load_keys_dict_from_bin(file_path)


def load_keys_dict_from_safetensors(file_path):
    keys_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            keys_dict[k] = f.get_slice(k).get_shape()
    return keys_dict


def convert_state_dict_to_keys_dict(state_dict):
    keys_dict = {}
    for k, v in state_dict.items():
        if isinstance(v, torch.Tensor):
            keys_dict[k] = list(v.shape)
        else:
            keys_dict[k] = convert_state_dict_to_keys_dict(v)
    return keys_dict


def load_keys_dict_from_bin(file_path):
    state_dict = load_state_dict_from_bin(file_path)
    keys_dict = convert_state_dict_to_keys_dict(state_dict)
    return keys_dict


def convert_keys_dict_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
        if isinstance(key, str):
            if isinstance(value, dict):
                keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
            else:
                if with_shape:
                    shape = "_".join(map(str, list(value)))
                    keys.append(key + ":" + shape)
                keys.append(key)
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str


def hash_model_file(path, with_shape=True):
    keys_dict = load_keys_dict(path)
    keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()
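A small sketch of the hashing helpers above (illustrative only); a toy linear layer stands in for a real checkpoint.

import torch
from diffsynth.core.loader import hash_state_dict_keys

state_dict = torch.nn.Linear(8, 4).state_dict()
print(hash_state_dict_keys(state_dict))                    # MD5 over sorted "key:shape" strings
print(hash_state_dict_keys(state_dict, with_shape=False))  # MD5 over key names only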
diffsynth/core/loader/model.py
ADDED
@@ -0,0 +1,79 @@
from ..vram.initialization import skip_model_initialization
from ..vram.disk_map import DiskMap
from ..vram.layers import enable_vram_management
from .file import load_state_dict
import torch


def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
    config = {} if config is None else config
    # Why do we use `skip_model_initialization`?
    # It skips the random initialization of model parameters,
    # thereby speeding up model loading and avoiding excessive memory usage.
    with skip_model_initialization():
        model = model_class(**config)
    # What is `module_map`?
    # This is a module mapping table for VRAM management.
    if module_map is not None:
        devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
        device = [d for d in devices if d != "disk"][0]
        dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
        dtype = [d for d in dtypes if d != "disk"][0]
        if vram_config["offload_device"] != "disk":
            state_dict = DiskMap(path, device, torch_dtype=dtype)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            else:
                state_dict = {i: state_dict[i] for i in state_dict}
            model.load_state_dict(state_dict, assign=True)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
        else:
            disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
            model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
    else:
        # Why do we use `DiskMap`?
        # Sometimes a model file contains multiple models,
        # and DiskMap can load only the parameters of a single model,
        # avoiding the need to load all parameters in the file.
        if use_disk_map:
            state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
        else:
            state_dict = load_state_dict(path, torch_dtype, device)
        # Why do we use `state_dict_converter`?
        # Some models are saved in complex formats,
        # and we need to convert the state dict into the appropriate format.
        if state_dict_converter is not None:
            state_dict = state_dict_converter(state_dict)
        else:
            state_dict = {i: state_dict[i] for i in state_dict}
        model.load_state_dict(state_dict, assign=True)
        # Why do we call `to()`?
        # Because some models override the behavior of `to()`,
        # especially those from libraries like Transformers.
        model = model.to(dtype=torch_dtype, device=device)
    if hasattr(model, "eval"):
        model = model.eval()
    return model


def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
    if isinstance(path, str):
        path = [path]
    config = {} if config is None else config
    with skip_model_initialization():
        model = model_class(**config)
    if hasattr(model, "eval"):
        model = model.eval()
    disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
    vram_config = {
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.float8_e4m3fn,
        "preparing_device": device,
        "computation_dtype": torch_dtype,
        "computation_device": device,
    }
    enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
    return model
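A toy end-to-end use of `load_model` (illustrative only; real callers pass DiffSynth model classes, safetensors paths and, optionally, a `module_map` plus `vram_config`):

import torch
from diffsynth.core.loader import load_model

class TinyModel(torch.nn.Module):   # hypothetical model class, stands in for a real DiffSynth model
    def __init__(self, dim=8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

torch.save(TinyModel().state_dict(), "tiny.pth")   # create something to load
model = load_model(TinyModel, "tiny.pth", config={"dim": 8}, torch_dtype=torch.float32, device="cpu")
print(next(model.parameters()).dtype)              # torch.float32; parameters were loaded with assign=True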
diffsynth/core/vram/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .initialization import skip_model_initialization
from .layers import *
diffsynth/core/vram/disk_map.py
ADDED
@@ -0,0 +1,93 @@
from safetensors import safe_open
import torch, os


class SafetensorsCompatibleTensor:
    def __init__(self, tensor):
        self.tensor = tensor

    def get_shape(self):
        return list(self.tensor.shape)


class SafetensorsCompatibleBinaryLoader:
    def __init__(self, path, device):
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
        self.state_dict = torch.load(path, weights_only=True, map_location=device)

    def keys(self):
        return self.state_dict.keys()

    def get_tensor(self, name):
        return self.state_dict[name]

    def get_slice(self, name):
        return SafetensorsCompatibleTensor(self.state_dict[name])


class DiskMap:

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        self.path = path if isinstance(path, list) else [path]
        self.device = device
        self.torch_dtype = torch_dtype
        if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
            self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
        else:
            self.buffer_size = buffer_size
        self.files = []
        self.flush_files()
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def flush_files(self):
        if len(self.files) == 0:
            for path in self.path:
                if path.endswith(".safetensors"):
                    self.files.append(safe_open(path, framework="pt", device=str(self.device)))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = safe_open(path, framework="pt", device=str(self.device))
        self.num_params = 0

    def __getitem__(self, name):
        if self.rename_dict is not None: name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        if self.torch_dtype is not None and isinstance(param, torch.Tensor):
            param = param.to(self.torch_dtype)
        if isinstance(param, torch.Tensor) and param.device == "cpu":
            param = param.clone()
        if isinstance(param, torch.Tensor):
            self.num_params += param.numel()
            if self.num_params > self.buffer_size:
                self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        if state_dict_converter is None:
            return None
        state_dict = {}
        for file in self.files:
            for name in file.keys():
                state_dict[name] = name
        state_dict = state_dict_converter(state_dict)
        return state_dict

    def __iter__(self):
        if self.rename_dict is not None:
            return self.rename_dict.__iter__()
        else:
            return self.name_map.__iter__()

    def __contains__(self, x):
        if self.rename_dict is not None:
            return x in self.rename_dict
        else:
            return x in self.name_map
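A minimal `DiskMap` sketch (illustrative only; the file name is a placeholder): tensors are fetched from the safetensors file one by one rather than loading the whole checkpoint into memory.

import torch
from safetensors.torch import save_file
from diffsynth.core.vram.disk_map import DiskMap

save_file({"weight": torch.randn(4, 4), "bias": torch.zeros(4)}, "toy.safetensors")
disk_map = DiskMap("toy.safetensors", device="cpu", torch_dtype=torch.float32)
print(list(disk_map))            # parameter names discovered in the file
print(disk_map["weight"].shape)  # only this tensor is read from disk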
diffsynth/core/vram/initialization.py
ADDED
@@ -0,0 +1,21 @@
import torch
from contextlib import contextmanager


@contextmanager
def skip_model_initialization(device=torch.device("meta")):

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    old_register_parameter = torch.nn.Module.register_parameter
    torch.nn.Module.register_parameter = register_empty_parameter
    try:
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
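A short demonstration of `skip_model_initialization` (illustrative only): modules built inside the context get meta-device parameters, so no memory is spent on random weights that `load_state_dict(..., assign=True)` would overwrite anyway.

import torch
from diffsynth.core.vram import skip_model_initialization

with skip_model_initialization():
    big = torch.nn.Linear(4096, 4096)
print(big.weight.device)   # meta: no storage was allocated for the placeholder weights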
diffsynth/core/vram/layers.py
ADDED
@@ -0,0 +1,475 @@
import torch, copy
from typing import Union
from .initialization import skip_model_initialization
from .disk_map import DiskMap


class AutoTorchModule(torch.nn.Module):

    def __init__(
        self,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
    ):
        super().__init__()
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.state = 0
        self.name = ""

    def set_dtype_and_device(
        self,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
    ):
        self.offload_dtype = offload_dtype or computation_dtype
        self.offload_device = offload_device or computation_device
        self.onload_dtype = onload_dtype or computation_dtype
        self.onload_device = onload_device or computation_device
        self.preparing_dtype = preparing_dtype or computation_dtype
        self.preparing_device = preparing_device or computation_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.vram_limit = vram_limit

    def cast_to(self, weight, dtype, device):
        r = torch.empty_like(weight, dtype=dtype, device=device)
        r.copy_(weight)
        return r

    def check_free_vram(self):
        gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
        used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
        return used_memory < self.vram_limit

    def offload(self):
        if self.state != 0:
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state != 1:
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def param_name(self, name):
        if self.name == "":
            return name
        else:
            return self.name + "." + name


class AutoWrappedModule(AutoTorchModule):

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.module = module
        if offload_dtype == "disk":
            self.name = name
            self.disk_map = disk_map
            self.required_params = [name for name, _ in self.module.named_parameters()]
            self.disk_offload = True
        else:
            self.disk_offload = False

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        if copy_module:
            module = copy.deepcopy(self.module)
        else:
            module = self.module
        state_dict = {}
        for name in self.required_params:
            param = self.disk_map[self.param_name(name)]
            param = param.to(dtype=torch_dtype, device=device)
            state_dict[name] = param
        module.load_state_dict(state_dict, assign=True)
        module.to(dtype=torch_dtype, device=device)
        return module

    def offload_to_disk(self, model: torch.nn.Module):
        for buf in model.buffers():
            # If some parameters are registered as buffers (not in the state dict),
            # we cannot offload the model directly; recurse into its children instead.
            for children in model.children():
                self.offload_to_disk(children)
            break
        else:
            model.to("meta")

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                self.offload_to_disk(self.module)
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def cast_to(self, module, dtype, device):
        return copy.deepcopy(module).to(dtype=dtype, device=device)

    def computation(self):
        # onload / preparing -> computation (temporary)
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            module = self.module
        elif self.disk_offload and device == "disk":
            module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
        else:
            module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
        return module

    def forward(self, *args, **kwargs):
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        module = self.computation()
        return module(*args, **kwargs)

    def __getattr__(self, name):
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)


class AutoWrappedNonRecurseModule(AutoWrappedModule):

    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        super().__init__(
            module,
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
            name,
            disk_map,
            **kwargs
        )
        if self.disk_offload:
            self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]

    def load_from_disk(self, torch_dtype, device, copy_module=False):
        if copy_module:
            module = copy.deepcopy(self.module)
        else:
            module = self.module
        state_dict = {}
        for name in self.required_params:
            param = self.disk_map[self.param_name(name)]
            param = param.to(dtype=torch_dtype, device=device)
            state_dict[name] = param
        module.load_state_dict(state_dict, assign=True, strict=False)
        return module

    def offload_to_disk(self, model: torch.nn.Module):
        for name in self.required_params:
            getattr(self, name).to("meta")

    def cast_to(self, module, dtype, device):
        # Parameter casting is implemented in the model architecture.
        return module

    def __getattr__(self, name):
        if name in self.__dict__ or name == "module":
            return super().__getattr__(name)
        else:
            return getattr(self.module, name)


class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
    def __init__(
        self,
        module: torch.nn.Linear,
        offload_dtype: torch.dtype = None,
        offload_device: Union[str, torch.device] = None,
        onload_dtype: torch.dtype = None,
        onload_device: Union[str, torch.device] = None,
        preparing_dtype: torch.dtype = None,
        preparing_device: Union[str, torch.device] = None,
        computation_dtype: torch.dtype = None,
        computation_device: Union[str, torch.device] = None,
        vram_limit: float = None,
        name: str = "",
        disk_map: DiskMap = None,
        **kwargs
    ):
        with skip_model_initialization():
            super().__init__(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
            )
        self.set_dtype_and_device(
            offload_dtype,
            offload_device,
            onload_dtype,
            onload_device,
            preparing_dtype,
            preparing_device,
            computation_dtype,
            computation_device,
            vram_limit,
        )
        self.weight = module.weight
        self.bias = module.bias
        self.state = 0
        self.name = name
        self.lora_A_weights = []
        self.lora_B_weights = []
        self.lora_merger = None
        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]

        if offload_dtype == "disk":
            self.disk_map = disk_map
            self.disk_offload = True
        else:
            self.disk_offload = False

    def fp8_linear(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor = None,
    ) -> torch.Tensor:
        device = input.device
        origin_dtype = input.dtype
        origin_shape = input.shape
        input = input.reshape(-1, origin_shape[-1])

        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
        fp8_max = 448.0
        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
        # To avoid overflow and ensure numerical compatibility during FP8 computation,
        # we scale down the input by 2.0 in advance.
        # This scaling will be compensated later during the final result scaling.
        if self.computation_dtype == torch.float8_e4m3fnuz:
            fp8_max = fp8_max / 2.0
        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
        input = input / (scale_a + 1e-8)
        input = input.to(self.computation_dtype)
        weight = weight.to(self.computation_dtype)
        bias = None if bias is None else bias.to(torch.bfloat16)

        result = torch._scaled_mm(
            input,
            weight.T,
            scale_a=scale_a,
            scale_b=scale_b.T,
            bias=bias,
            out_dtype=origin_dtype,
        )
        new_shape = origin_shape[:-1] + result.shape[-1:]
        result = result.reshape(new_shape)
        return result

    def load_from_disk(self, torch_dtype, device, assign=True):
        weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
        bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
        if assign:
            state_dict = {"weight": weight}
            if bias is not None: state_dict["bias"] = bias
            self.load_state_dict(state_dict, assign=True)
        return weight, bias

    def offload(self):
        # offload / onload / preparing -> offload
        if self.state != 0:
            if self.disk_offload:
                self.to("meta")
            else:
                self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        # offload / onload / preparing -> onload
        if self.state < 1:
            if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
                self.load_from_disk(self.onload_dtype, self.onload_device)
            elif self.onload_device != "disk":
                self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def preparing(self):
        # onload / preparing -> preparing
        if self.state != 2:
            if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
                self.load_from_disk(self.preparing_dtype, self.preparing_device)
            elif self.preparing_device != "disk":
                self.to(dtype=self.preparing_dtype, device=self.preparing_device)
            self.state = 2

    def computation(self):
        # onload / preparing -> computation (temporary)
        if self.state == 2:
            torch_dtype, device = self.preparing_dtype, self.preparing_device
        else:
            torch_dtype, device = self.onload_dtype, self.onload_device
        if torch_dtype == self.computation_dtype and device == self.computation_device:
            weight, bias = self.weight, self.bias
        elif self.disk_offload and device == "disk":
            weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
        else:
            weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
            bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
        return weight, bias

    def linear_forward(self, x, weight, bias):
        if self.enable_fp8:
            out = self.fp8_linear(x, weight, bias)
        else:
            out = torch.nn.functional.linear(x, weight, bias)
        return out

    def lora_forward(self, x, out):
        if self.lora_merger is None:
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                out = out + x @ lora_A.T @ lora_B.T
        else:
            lora_output = []
            for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
                lora_output.append(x @ lora_A.T @ lora_B.T)
            lora_output = torch.stack(lora_output)
            out = self.lora_merger(out, lora_output)
        return out

    def forward(self, x, *args, **kwargs):
        if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
            self.preparing()
        weight, bias = self.computation()
        out = self.linear_forward(x, weight, bias)
        if len(self.lora_A_weights) > 0:
            out = self.lora_forward(x, out)
        return out


def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
    if isinstance(model, AutoWrappedNonRecurseModule):
        model = model.module
    for name, module in model.named_children():
        layer_name = name if name_prefix == "" else name_prefix + "." + name
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
                if isinstance(module_, AutoWrappedNonRecurseModule):
                    enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
                setattr(model, name, module_)
                break
        else:
            enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)


def fill_vram_config(model, vram_config):
    vram_config_ = vram_config.copy()
    vram_config_["onload_dtype"] = vram_config["computation_dtype"]
    vram_config_["onload_device"] = vram_config["computation_device"]
    vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
    vram_config_["preparing_device"] = vram_config["computation_device"]
    for k in vram_config:
        if vram_config[k] != vram_config_[k]:
            print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
            break
    return vram_config_


def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
    for source_module, target_module in module_map.items():
        # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
        if isinstance(model, source_module):
            vram_config = fill_vram_config(model, vram_config)
            model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
            break
    else:
        enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
    # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
    model.vram_management_enabled = True
    return model
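A toy wrapping example for `enable_vram_management` (illustrative only; real module maps are defined elsewhere in the repository, and the all-CPU config below exists just to keep the sketch runnable without a GPU):

import torch
from diffsynth.core.vram.layers import AutoWrappedLinear, enable_vram_management

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16))
vram_config = {
    "offload_dtype": torch.bfloat16, "offload_device": "cpu",
    "onload_dtype": torch.bfloat16, "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16, "preparing_device": "cpu",
    "computation_dtype": torch.bfloat16, "computation_device": "cpu",
}
model = enable_vram_management(model, {torch.nn.Linear: AutoWrappedLinear}, vram_config)
print(type(model[0]).__name__)        # AutoWrappedLinear
print(model.vram_management_enabled)  # True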
diffsynth/datasets/mvdataset.py
ADDED
@@ -0,0 +1,393 @@
import imageio, os, torch, warnings, torchvision, argparse, json
import random
import re
import numpy as np
import pandas as pd
import imageio.v3 as iio
import decord
from decord import VideoReader, cpu, gpu
from PIL import Image, ImageOps
from peft import LoraConfig, inject_adapter_in_model
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torchvision import transforms

decord.bridge.set_bridge('torch')


class MulltiShot_MultiView_Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_base_path='/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/merged_mark_paishe_ds_meg_merge_dwposefilter_paishe.json',
                 ref_image_path='/root/paddlejob/workspace/qizipeng/code/longvideogen/output.json',
                 time_division_factor=4,
                 time_division_remainder=1,
                 max_pixels=1920*1080,
                 height_division_factor=16, width_division_factor=16,
                 transform=None,
                 length=None,
                 resolution=None,
                 prev_length=5,
                 ref_num=3,
                 training=True):
        self.data_path = dataset_base_path
        self.data = []
        self.length = length
        self.resolution = resolution
        self.height, self.width = resolution
        self.num_frames = length
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.prev_length = prev_length
        self.training = training
        self.ref_num = ref_num

        with open(self.data_path, 'r') as f:
            meta_datas = json.load(f)

        for video_path in tqdm(meta_datas.keys()):
            context = meta_datas[video_path]
            candidate_labels = list(context.keys())
            candidate_labels.remove('text')

            disk_path = meta_datas[video_path]["disk_path"]
            if not disk_path.lower().endswith(".mp4"):
                continue

            # reader = imageio.get_reader(meta_datas[video_path]["disk_path"])
            # total_original_frames = reader.count_frames()
            # total_frame = total_original_frames  # context["end_index"] - context["start_index"] - 1
            total_frame = None
            ref_id = self.get_ref_id(face_crop_angle=context['facedetect_v1'], facedetect_v1_frame_index=context['facedetect_v1_frame_index'], total_frame=total_frame)
            if ref_id == []:
                continue
            ref_id_all = []
            for ids in ref_id:
                ref_id_grop = []
                for id in ids:
                    coordinate = context['facedetect_v1'][id][0]['detect']
                    if context['facedetect_v1'][id][0]['detect']["prob"] < 0.99:
                        continue
                    top, height, width, left = coordinate['top'], coordinate['height'], coordinate['width'], coordinate['left']
                    if not (min(height, width) > 80):
                        continue
                    # enlarge bbox (scale factor currently 1x)
                    width = int(width * 1)
                    height = int(height * 1)
                    frame_index = context['facedetect_v1_frame_index'][id]
                    ref_id_grop.append([top, height, width, left, int(frame_index)])
                if ref_id_grop != []:
                    if len(ref_id_grop) >= 3:  # self.ref_num: kept at 3 so the data stays consistent with ref_num = 3
                        ref_id_all.append(ref_id_grop)
            if ref_id_all == []:
                continue
            meta_prompt = {}
            meta_prompt["global_caption"] = None
            meta_prompt["per_shot_prompt"] = []
            meta_prompt["single_prompt"] = context['text']
            self.data.append({'video_path': disk_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all})
            # self.data.append({'video_path': video_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all})

        random.seed(42)  # make the train/test split reproducible (optional)
        total = len(self.data)
        test_count = max(1, int(total * 0.05))  # at least one sample

        # randomly select the test indices
        test_indices = set(random.sample(range(total), test_count))

        self.data_test = [self.data[i] for i in range(total) if i in test_indices]
        self.data_train = [self.data[i] for i in range(total) if i not in test_indices]
        print(f"🔥 Dataset split finished: Train={len(self.data_train)}, Test={len(self.data_test)}")

        if self.height is not None and self.width is not None:
            print("Height and width are fixed. Setting `dynamic_resolution` to False.")
            self.dynamic_resolution = False
        elif self.height is None and self.width is None:
            print("Height and width are none. Setting `dynamic_resolution` to True.")
            self.dynamic_resolution = True

    def get_ref_id(self, face_crop_angle, facedetect_v1_frame_index=None, total_frame=None, angle_threshold=50):
        """
        Return triplets [i, j, k] that satisfy the angle-difference requirement.
        Requirements:
        - face_crop_angle[i] / [j] / [k] must all be non-empty
        - at least one of the yaw/pitch/roll differences between i and j exceeds angle_threshold
        - k != i != j, and k must also be non-empty
        """
        ref_id = []
        max_try = 5
        need_max = 3
        try_num = 0

        # filter out empty entries and keep the valid indices
        valid_indices = [idx for idx, item in enumerate(face_crop_angle) if item]
        N = len(valid_indices)

        if N < 3:
            return ref_id  # fewer than 3 valid faces, cannot form a triplet

        # check angle differences for each pair
        for a in range(N - 1):
            i = valid_indices[a]
            # if facedetect_v1_frame_index[i] > total_frame:
            #     continue
            angle_i = face_crop_angle[i][0]["angle"]

            for b in range(a + 1, N):
                j = valid_indices[b]
                # if facedetect_v1_frame_index[j] > total_frame:
                #     continue
                angle_j = face_crop_angle[j][0]["angle"]

                # check whether the threshold is met
                if (
                    abs(angle_i["yaw"] - angle_j["yaw"]) > angle_threshold or
                    abs(angle_i["pitch"] - angle_j["pitch"]) > angle_threshold or
                    abs(angle_i["roll"] - angle_j["roll"]) > angle_threshold
                ):
                    # find a third index k
                    for c in range(N):
                        k = valid_indices[c]
                        # if facedetect_v1_frame_index[k] > total_frame:
                        #     continue
                        if k != i and k != j:
                            ref_id.append([i, j, k])
                            break

                try_num += 1
                if try_num >= max_try or len(ref_id) >= need_max:
                    return ref_id

        return ref_id

    def crop_and_resize(self, image, target_height, target_width):
        width, height = image.size
        scale = max(target_width / width, target_height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height*scale), round(width*scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
        return image

    def get_height_width(self, image):
        if self.dynamic_resolution:
            width, height = image.size
            if width * height > self.max_pixels:
                scale = (width * height / self.max_pixels) ** 0.5
                height, width = int(height / scale), int(width / scale)
            height = height // self.height_division_factor * self.height_division_factor
            width = width // self.width_division_factor * self.width_division_factor
        else:
            height, width = self.height, self.width
        return height, width

    # def
    #     img_ratio = img.width / img.height
    #     target_ratio = w / h
    #     if img_ratio > target_ratio:  # Image is wider than target
    #         new_width = w
    #         new_height = int(new_width / img_ratio)
    #     else:  # Image is taller than target
    #         new_height = h
    #         new_width = int(new_height * img_ratio)

    #     # img = img.resize((new_width, new_height), Image.ANTIALIAS)
    #     img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

    #     # Create a new image with the target size and place the resized image in the center
    #     delta_w = w - img.size[0]
    #     delta_h = h - img.size[1]
    #     padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
    #     new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))

    def resize_ref(self, img, target_h, target_w):
        h = target_h
        w = target_w
        img = img.convert("RGB")
        # Calculate the required size to keep aspect ratio and fill the rest with padding.
        img_ratio = img.width / img.height
        target_ratio = w / h

        if img_ratio > target_ratio:  # Image is wider than target
            new_width = w
            new_height = int(new_width / img_ratio)
        else:  # Image is taller than target
            new_height = h
            new_width = int(new_height * img_ratio)

        # img = img.resize((new_width, new_height), Image.ANTIALIAS)
        img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Create a new image with the target size and place the resized image in the center
        delta_w = w - img.size[0]
        delta_h = h - img.size[1]
        padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
        new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))

        return new_img

    def load_video_crop_ref_image(self, video_path=None, ref_id_all=[[]]):
        # fps conversion
        reader = imageio.get_reader(video_path)
        meta = reader.get_meta_data()
        original_fps = meta.get("fps", 24)
        target_fps = 16
        duration_seconds = 5
        target_frames = target_fps * duration_seconds + 1  # = 81 frames

        # ---- get the original frame count ----
        try:
total_original_frames = reader.count_frames()
|
| 251 |
+
except:
|
| 252 |
+
total_original_frames = int(meta.get("duration", 5) * original_fps)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ---- 需要多少原始帧(5秒)----
|
| 257 |
+
need_orig_frames = int(original_fps * duration_seconds)
|
| 258 |
+
|
| 259 |
+
# ---- Case 1: 原视频 >= 5秒 → 随机选择 5 秒起点 ----
|
| 260 |
+
if total_original_frames > need_orig_frames:
|
| 261 |
+
max_start = total_original_frames - need_orig_frames
|
| 262 |
+
start_frame = random.randint(0, max_start)
|
| 263 |
+
segment_start = start_frame
|
| 264 |
+
segment_end = start_frame + need_orig_frames
|
| 265 |
+
else:
|
| 266 |
+
# ---- Case 2: 原视频 < 5秒 → 用全部帧 ----
|
| 267 |
+
segment_start = 0
|
| 268 |
+
segment_end = total_original_frames
|
| 269 |
+
|
| 270 |
+
# ---- 均匀采样 80 帧 ----
|
| 271 |
+
sample_ids = np.linspace(segment_start, segment_end - 1, num=target_frames, dtype=int)
|
| 272 |
+
|
| 273 |
+
frames = []
|
| 274 |
+
for frame_id in sample_ids:
|
| 275 |
+
frame = reader.get_data(int(frame_id))
|
| 276 |
+
frame = Image.fromarray(frame)
|
| 277 |
+
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
|
| 278 |
+
frames.append(frame)
|
| 279 |
+
|
| 280 |
+
# ===========================
|
| 281 |
+
# 选择参考图部分(你要求的)
|
| 282 |
+
# ===========================
|
| 283 |
+
|
| 284 |
+
# 1)从 ref_images_all(三维 list)里随机选一组
|
| 285 |
+
# ref_images_all = [ [img1, img2, img3], [imgA, imgB, imgC], ... ]
|
| 286 |
+
ref_group = random.choice(ref_id_all)
|
| 287 |
+
|
| 288 |
+
# 2)检查资源是否足够
|
| 289 |
+
if len(ref_group) < self.ref_num:
|
| 290 |
+
raise ValueError(f"需要 {self.ref_num} 张参考图,但该组只有 {len(ref_group)} 张。")
|
| 291 |
+
|
| 292 |
+
# 3)从该组中随机选 self.ref_num 张
|
| 293 |
+
selected_refs = random.sample(ref_group, self.ref_num)
|
| 294 |
+
random.shuffle(selected_refs)
|
| 295 |
+
|
| 296 |
+
ref_images = []
|
| 297 |
+
for sf in selected_refs:
|
| 298 |
+
top, height, width, left, frame_index = sf
|
| 299 |
+
# import pdb; pdb.set_trace()
|
| 300 |
+
if frame_index > total_original_frames:
|
| 301 |
+
print(f"{video_path}, frame_index({frame_index}) out of range")
|
| 302 |
+
frame = reader.get_data(int(frame_index))
|
| 303 |
+
frame = Image.fromarray(frame)
|
| 304 |
+
xmin, ymin, xmax, ymax = left, top, left + width, top + height
|
| 305 |
+
cropped_image = frame.crop((xmin, ymin, xmax, ymax)).convert("RGB")
|
| 306 |
+
cropped_image = self.resize_ref(cropped_image, self.height, self.width)
|
| 307 |
+
# Calculate the required size to keep aspect ratio and fill the rest with padding.
|
| 308 |
+
ref_images.append(cropped_image)
|
| 309 |
+
reader.close()
|
| 310 |
+
|
| 311 |
+
return frames, ref_images
|
| 312 |
+
|
| 313 |
+
def __getitem__(self, index):
|
| 314 |
+
max_retry = 10 # 最多重试 10 次,避免死循环
|
| 315 |
+
retry = 0
|
| 316 |
+
|
| 317 |
+
while retry < max_retry:
|
| 318 |
+
# ----- 选择 train / test 数据 -----
|
| 319 |
+
if self.training:
|
| 320 |
+
meta_data = self.data_train[index % len(self.data_train)]
|
| 321 |
+
else:
|
| 322 |
+
meta_data = self.data_test[index % len(self.data_test)]
|
| 323 |
+
|
| 324 |
+
video_path = meta_data['video_path']
|
| 325 |
+
meta_prompt = meta_data['meta_prompt']
|
| 326 |
+
ref_id_all = meta_data['ref_id_all']
|
| 327 |
+
|
| 328 |
+
# ----- 尝试读取 video + ref -----
|
| 329 |
+
try:
|
| 330 |
+
input_video, ref_images = self.load_video_crop_ref_image(
|
| 331 |
+
video_path=video_path,
|
| 332 |
+
ref_id_all=ref_id_all
|
| 333 |
+
)
|
| 334 |
+
except Exception as e:
|
| 335 |
+
print("❌ Exception in load_video_crop_ref_image")
|
| 336 |
+
print(f" video_path: {video_path}")
|
| 337 |
+
print(f" error type: {type(e).__name__}")
|
| 338 |
+
print(f" error msg : {e}")
|
| 339 |
+
|
| 340 |
+
# 打印 traceback,定位问题更容易
|
| 341 |
+
import traceback
|
| 342 |
+
traceback.print_exc()
|
| 343 |
+
input_video = None
|
| 344 |
+
ref_images = None
|
| 345 |
+
# ----- 如果成功,并且 video 不为空,返回结果 -----
|
| 346 |
+
if input_video is not None and len(input_video) > 0:
|
| 347 |
+
return {
|
| 348 |
+
"global_caption": None,
|
| 349 |
+
"shot_num": 1,
|
| 350 |
+
"pre_shot_caption": [],
|
| 351 |
+
"single_caption": meta_prompt["single_prompt"],
|
| 352 |
+
"video": input_video,
|
| 353 |
+
"ref_num": self.ref_num,
|
| 354 |
+
"ref_images": ref_images,
|
| 355 |
+
"video_path": video_path
|
| 356 |
+
}
|
| 357 |
+
|
| 358 |
+
# ----- 如果失败,换 index,并继续尝试 -----
|
| 359 |
+
retry += 1
|
| 360 |
+
index = random.randint(0, len(self.data_train) - 1 if self.training else len(self.data_test) - 1)
|
| 361 |
+
|
| 362 |
+
# 若 10 次都失败,返回最后一次的错误内容
|
| 363 |
+
raise RuntimeError(f"❌ [Dataset] Failed to load video/ref after {max_retry} retries.")
|
| 364 |
+
|
| 365 |
+
def __len__(self):
|
| 366 |
+
if self.training:
|
| 367 |
+
return len(self.data_train)
|
| 368 |
+
else:
|
| 369 |
+
return len(self.data_test)
|
| 370 |
+
|
| 371 |
+
if __name__ == '__main__':
|
| 372 |
+
from torch.utils.data import DataLoader
|
| 373 |
+
dataset = MulltiShot_MultiView_Dataset(length=49, resolution=(384, 640), training=True)
|
| 374 |
+
print(len(dataset))
|
| 375 |
+
metadata = dataset[0]
|
| 376 |
+
# results = dataset[0]
|
| 377 |
+
# loader = DataLoader(
|
| 378 |
+
# dataset,
|
| 379 |
+
# batch_size=1, # 视频一般 batch=1
|
| 380 |
+
# shuffle=False, # 你想打乱就 True
|
| 381 |
+
# num_workers=10, # ⭐ 重点:开启 8 个子进程并行加载
|
| 382 |
+
# pin_memory=True,
|
| 383 |
+
# prefetch_factor=2, # 每个 worker 预读取 2 个样本
|
| 384 |
+
# collate_fn=lambda x: x[0], # ⭐ 不做任何 collate
|
| 385 |
+
# )
|
| 386 |
+
|
| 387 |
+
# for batch in tqdm(loader):
|
| 388 |
+
# pass
|
| 389 |
+
for i in tqdm(range(len(dataset))):
|
| 390 |
+
file = dataset[i]
|
| 391 |
+
|
| 392 |
+
assert 0
|
| 393 |
+
|
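Note: the commented-out DataLoader block in the `__main__` section above hints at the intended multi-worker loading. A minimal sketch along those lines, assuming this dataset lives at `diffsynth/datasets/mvdataset.py` from this upload and using the constructor arguments shown in `__main__` (all values illustrative):

# Illustrative loader for the dataset above (import path and settings are assumptions).
from torch.utils.data import DataLoader
from tqdm import tqdm
from diffsynth.datasets.mvdataset import MulltiShot_MultiView_Dataset

dataset = MulltiShot_MultiView_Dataset(length=49, resolution=(384, 640), training=True)
loader = DataLoader(
    dataset,
    batch_size=1,                 # one video clip per batch
    shuffle=True,
    num_workers=8,                # decode videos in parallel worker processes
    pin_memory=True,
    prefetch_factor=2,
    collate_fn=lambda x: x[0],    # samples are plain dicts, so skip collation
)
for sample in tqdm(loader):
    frames, refs = sample["video"], sample["ref_images"]  # lists of PIL images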
diffsynth/diffusion/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .flow_match import FlowMatchScheduler
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
from .runner import launch_training_task, launch_data_process_task
from .parsers import *
from .loss import *
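Note: a minimal sketch of the re-exported API, assuming the package import path `diffsynth.diffusion`; the argument values are illustrative:

from diffsynth.diffusion import FlowMatchScheduler, ModelLogger

scheduler = FlowMatchScheduler(template="Wan")
scheduler.set_timesteps(num_inference_steps=50, shift=5)
logger = ModelLogger(output_path="./models", remove_prefix_in_ckpt="pipe.dit.")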
diffsynth/diffusion/base_pipeline.py
ADDED
@@ -0,0 +1,439 @@
from PIL import Image
import torch
import numpy as np
from einops import repeat, reduce
from typing import Union
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
from ..utils.lora import GeneralLoRALoader
from ..models.model_loader import ModelPool
from ..utils.controlnet import ControlNetInput


class PipelineUnit:
    def __init__(
        self,
        seperate_cfg: bool = False,
        take_over: bool = False,
        input_params: tuple[str] = None,
        output_params: tuple[str] = None,
        input_params_posi: dict[str, str] = None,
        input_params_nega: dict[str, str] = None,
        onload_model_names: tuple[str] = None
    ):
        self.seperate_cfg = seperate_cfg
        self.take_over = take_over
        self.input_params = input_params
        self.output_params = output_params
        self.input_params_posi = input_params_posi
        self.input_params_nega = input_params_nega
        self.onload_model_names = onload_model_names

    def fetch_input_params(self):
        params = []
        if self.input_params is not None:
            for param in self.input_params:
                params.append(param)
        if self.input_params_posi is not None:
            for _, param in self.input_params_posi.items():
                params.append(param)
        if self.input_params_nega is not None:
            for _, param in self.input_params_nega.items():
                params.append(param)
        params = sorted(list(set(params)))
        return params

    def fetch_output_params(self):
        params = []
        if self.output_params is not None:
            for param in self.output_params:
                params.append(param)
        return params

    def process(self, pipe, **kwargs) -> dict:
        return {}

    def post_process(self, pipe, **kwargs) -> dict:
        return {}


class BasePipeline(torch.nn.Module):

    def __init__(
        self,
        device="cuda", torch_dtype=torch.float16,
        height_division_factor=64, width_division_factor=64,
        time_division_factor=None, time_division_remainder=None,
    ):
        super().__init__()
        # The device and torch_dtype is used for the storage of intermediate variables, not models.
        self.device = device
        self.torch_dtype = torch_dtype
        # The following parameters are used for shape check.
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # VRAM management
        self.vram_management_enabled = False
        # Pipeline Unit Runner
        self.unit_runner = PipelineUnitRunner()
        # LoRA Loader
        self.lora_loader = GeneralLoRALoader

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def check_resize_height_width(self, height, width, num_frames=None):
        # Shape check
        if height % self.height_division_factor != 0:
            height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
            print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
        if width % self.width_division_factor != 0:
            width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
            print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
        if num_frames is None:
            return height, width
        else:
            if num_frames % self.time_division_factor != self.time_division_remainder:
                num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
                print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
            return height, width, num_frames

    def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a PIL.Image to torch.Tensor
        image = torch.Tensor(np.array(image, dtype=np.float32))
        image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        image = image * ((max_value - min_value) / 255) + min_value
        image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
        return image

    def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a list of PIL.Image to torch.Tensor
        video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
        video = torch.stack(video, dim=pattern.index("T") // 2)
        return video

    def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to PIL.Image
        if pattern != "H W C":
            vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
        image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
        image = image.to(device="cpu", dtype=torch.uint8)
        image = Image.fromarray(image.numpy())
        return image

    def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
        # Transform a torch.Tensor to list of PIL.Image
        if pattern != "T H W C":
            vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
        video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
        return video

    def load_models_to_device(self, model_names):
        if self.vram_management_enabled:
            # offload models
            for name, model in self.named_children():
                if name not in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "offload"):
                            model.offload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "offload"):
                                    module.offload()
            torch.cuda.empty_cache()
            # onload models
            for name, model in self.named_children():
                if name in model_names:
                    if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
                        if hasattr(model, "onload"):
                            model.onload()
                        else:
                            for module in model.modules():
                                if hasattr(module, "onload"):
                                    module.onload()

    def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
        # Initialize Gaussian noise
        generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
        noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
        noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
        return noise

    def get_vram(self):
        return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)

    def get_module(self, model, name):
        if "." in name:
            name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
            if name.isdigit():
                return self.get_module(model[int(name)], suffix)
            else:
                return self.get_module(getattr(model, name), suffix)
        else:
            return getattr(model, name)

    def freeze_except(self, model_names):
        self.eval()
        self.requires_grad_(False)
        for name in model_names:
            module = self.get_module(self, name)
            if module is None:
                print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
                continue
            module.train()
            module.requires_grad_(True)

    def blend_with_mask(self, base, addition, mask):
        return base * (1 - mask) + addition * mask

    def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
        timestep = scheduler.timesteps[progress_id]
        if inpaint_mask is not None:
            noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
            noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
        latents_next = scheduler.step(noise_pred, timestep, latents)
        return latents_next

    def split_pipeline_units(self, model_names: list[str]):
        return PipelineUnitGraph().split_pipeline_units(self.units, model_names)

    def flush_vram_management_device(self, device):
        for module in self.modules():
            if isinstance(module, AutoTorchModule):
                module.offload_device = device
                module.onload_device = device
                module.preparing_device = device
                module.computation_device = device

    def load_lora(
        self,
        module: torch.nn.Module,
        lora_config: Union[ModelConfig, str] = None,
        alpha=1,
        hotload=None,
        state_dict=None,
    ):
        if state_dict is None:
            if isinstance(lora_config, str):
                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
            else:
                lora_config.download_if_necessary()
                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
        else:
            lora = state_dict
        lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
        lora = lora_loader.convert_state_dict(lora)
        if hotload is None:
            hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
        if hotload:
            if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
                raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
            updated_num = 0
            for _, module in module.named_modules():
                if isinstance(module, AutoWrappedLinear):
                    name = module.name
                    lora_a_name = f'{name}.lora_A.weight'
                    lora_b_name = f'{name}.lora_B.weight'
                    if lora_a_name in lora and lora_b_name in lora:
                        updated_num += 1
                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
                        module.lora_B_weights.append(lora[lora_b_name])
            print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
        else:
            lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)

    def clear_lora(self):
        cleared_num = 0
        for name, module in self.named_modules():
            if isinstance(module, AutoWrappedLinear):
                if hasattr(module, "lora_A_weights"):
                    if len(module.lora_A_weights) > 0:
                        cleared_num += 1
                    module.lora_A_weights.clear()
                if hasattr(module, "lora_B_weights"):
                    module.lora_B_weights.clear()
        print(f"{cleared_num} LoRA layers are cleared.")

    def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
        model_pool = ModelPool()
        for model_config in model_configs:
            model_config.download_if_necessary()
            vram_config = model_config.vram_config()
            vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
            vram_config["computation_device"] = vram_config["computation_device"] or self.device
            model_pool.auto_load_model(
                model_config.path,
                vram_config=vram_config,
                vram_limit=vram_limit,
                clear_parameters=model_config.clear_parameters,
            )
        return model_pool

    def check_vram_management_state(self):
        vram_management_enabled = False
        for module in self.children():
            if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
                vram_management_enabled = True
        return vram_management_enabled

    def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
        noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
        if cfg_scale != 1.0:
            noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
            noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
        else:
            noise_pred = noise_pred_posi
        return noise_pred


class PipelineUnitGraph:
    def __init__(self):
        pass

    def build_edges(self, units: list[PipelineUnit]):
        # Establish dependencies between units
        # to search for subsequent related computation units.
        last_compute_unit_id = {}
        edges = []
        for unit_id, unit in enumerate(units):
            for input_param in unit.fetch_input_params():
                if input_param in last_compute_unit_id:
                    edges.append((last_compute_unit_id[input_param], unit_id))
            for output_param in unit.fetch_output_params():
                last_compute_unit_id[output_param] = unit_id
        return edges

    def build_chains(self, units: list[PipelineUnit]):
        # Establish updating chains for each variable
        # to track their computation process.
        params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
        params = sorted(list(set(params)))
        chains = {param: [] for param in params}
        for unit_id, unit in enumerate(units):
            for param in unit.fetch_output_params():
                chains[param].append(unit_id)
        return chains

    def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
        # Search for units that directly participate in the model's computation.
        related_unit_ids = []
        for unit_id, unit in enumerate(units):
            for model_name in model_names:
                if unit.onload_model_names is not None and model_name in unit.onload_model_names:
                    related_unit_ids.append(unit_id)
                    break
        return related_unit_ids

    def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
        # Search for subsequent related computation units.
        related_unit_ids = [unit_id for unit_id in start_unit_ids]
        while True:
            neighbors = []
            for source, target in edges:
                if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
                    neighbors.append(target)
                elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
                    neighbors.append(source)
            neighbors = sorted(list(set(neighbors)))
            if len(neighbors) == 0:
                break
            else:
                related_unit_ids.extend(neighbors)
                related_unit_ids = sorted(list(set(related_unit_ids)))
        return related_unit_ids

    def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
        # If the input parameters of this subgraph are updated outside the subgraph,
        # search for the units where these updates occur.
        first_compute_unit_id = {}
        for unit_id in related_unit_ids:
            for param in units[unit_id].fetch_input_params():
                if param not in first_compute_unit_id:
                    first_compute_unit_id[param] = unit_id
        updating_unit_ids = []
        for param in first_compute_unit_id:
            unit_id = first_compute_unit_id[param]
            chain = chains[param]
            if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
                for unit_id_ in chain[chain.index(unit_id) + 1:]:
                    if unit_id_ not in related_unit_ids:
                        updating_unit_ids.append(unit_id_)
        related_unit_ids.extend(updating_unit_ids)
        related_unit_ids = sorted(list(set(related_unit_ids)))
        return related_unit_ids

    def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
        # Split the computation graph,
        # separating all model-related computations.
        related_unit_ids = self.search_direct_unit_ids(units, model_names)
        edges = self.build_edges(units)
        chains = self.build_chains(units)
        while True:
            num_related_unit_ids = len(related_unit_ids)
            related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
            related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
            if len(related_unit_ids) == num_related_unit_ids:
                break
            else:
                num_related_unit_ids = len(related_unit_ids)
        related_units = [units[i] for i in related_unit_ids]
        unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
        return related_units, unrelated_units


class PipelineUnitRunner:
    def __init__(self):
        pass

    def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
        if unit.take_over:
            # Let the pipeline unit take over this function.
            inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
        elif unit.seperate_cfg:
            # Positive side
            processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
            if unit.input_params is not None:
                for name in unit.input_params:
                    processor_inputs[name] = inputs_shared.get(name)
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_posi.update(processor_outputs)
            # Negative side
            if inputs_shared["cfg_scale"] != 1:
                processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
                if unit.input_params is not None:
                    for name in unit.input_params:
                        processor_inputs[name] = inputs_shared.get(name)
                processor_outputs = unit.process(pipe, **processor_inputs)
                inputs_nega.update(processor_outputs)
            else:
                inputs_nega.update(processor_outputs)
        else:
            processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
            processor_outputs = unit.process(pipe, **processor_inputs)
            inputs_shared.update(processor_outputs)
        return inputs_shared, inputs_posi, inputs_nega
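Note: a minimal sketch of how PipelineUnit and PipelineUnitRunner compose, assuming the import path `diffsynth.diffusion.base_pipeline` added in this diff. The toy unit, its parameter names, and the fake "embedding" are illustrative, not part of the library:

# A toy unit that reads "prompt" from the shared inputs and writes "prompt_emb".
from diffsynth.diffusion.base_pipeline import PipelineUnit, PipelineUnitRunner

class ToyPromptUnit(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("prompt",), output_params=("prompt_emb",))

    def process(self, pipe, prompt):
        return {"prompt_emb": prompt.upper()}  # stand-in for a real text encoder

runner = PipelineUnitRunner()
inputs_shared, inputs_posi, inputs_nega = {"prompt": "a cat", "cfg_scale": 1.0}, {}, {}
inputs_shared, inputs_posi, inputs_nega = runner(
    ToyPromptUnit(), pipe=None,
    inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega,
)
print(inputs_shared["prompt_emb"])  # "A CAT"

With seperate_cfg=True the runner would instead call the unit once per classifier-free-guidance branch, routing inputs through input_params_posi / input_params_nega.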
diffsynth/diffusion/flow_match.py
ADDED
@@ -0,0 +1,179 @@
import torch, math
from typing_extensions import Literal


class FlowMatchScheduler():

    def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
        self.set_timesteps_fn = {
            "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
            "Wan": FlowMatchScheduler.set_timesteps_wan,
            "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
            "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
            "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
        }.get(template, FlowMatchScheduler.set_timesteps_flux)
        self.num_train_timesteps = 1000

    @staticmethod
    def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.003 / 1.002
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 5 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b
        return mu

    @staticmethod
    def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
        sigma_min = 0.0
        sigma_max = 1.0
        num_train_timesteps = 1000
        shift_terminal = 0.02
        # Sigmas
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        # Mu
        if exponential_shift_mu is not None:
            mu = exponential_shift_mu
        elif dynamic_shift_len is not None:
            mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
        else:
            mu = 0.8
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        # Shift terminal
        one_minus_z = 1 - sigmas
        scale_factor = one_minus_z[-1] / (1 - shift_terminal)
        sigmas = 1 - (one_minus_z / scale_factor)
        # Timesteps
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def compute_empirical_mu(image_seq_len, num_steps):
        a1, b1 = 8.73809524e-05, 1.89833333
        a2, b2 = 0.00016927, 0.45666666

        if image_seq_len > 4300:
            mu = a2 * image_seq_len + b2
            return float(mu)

        m_200 = a2 * image_seq_len + b2
        m_10 = a1 * image_seq_len + b1

        a = (m_200 - m_10) / 190.0
        b = m_200 - 200.0 * a
        mu = a * num_steps + b

        return float(mu)

    @staticmethod
    def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
        sigma_min = 1 / num_inference_steps
        sigma_max = 1.0
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
        mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
        sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
        timesteps = sigmas * num_train_timesteps
        return sigmas, timesteps

    @staticmethod
    def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
        sigma_min = 0.0
        sigma_max = 1.0
        shift = 3 if shift is None else shift
        num_train_timesteps = 1000
        sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
        sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * num_train_timesteps
        if target_timesteps is not None:
            target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
            for timestep in target_timesteps:
                timestep_id = torch.argmin((timesteps - timestep).abs())
                timesteps[timestep_id] = timestep
        return sigmas, timesteps

    def set_training_weight(self):
        steps = 1000
        x = self.timesteps
        y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
        y_shifted = y - y.min()
        bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
        if len(self.timesteps) != 1000:
            # This is an empirical formula.
            bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
            bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
        self.linear_timesteps_weights = bsmntw_weighing

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
        self.sigmas, self.timesteps = self.set_timesteps_fn(
            num_inference_steps=num_inference_steps,
            denoising_strength=denoising_strength,
            **kwargs,
        )
        if training:
            self.set_training_weight()
            self.training = True
        else:
            self.training = False

    def step(self, model_output, timestep, sample, to_final=False, **kwargs):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        if to_final or timestep_id + 1 >= len(self.timesteps):
            sigma_ = 0
        else:
            sigma_ = self.sigmas[timestep_id + 1]
        prev_sample = sample + model_output * (sigma_ - sigma)
        return prev_sample

    def return_to_timestep(self, timestep, sample, sample_stablized):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        model_output = (sample - sample_stablized) / sigma
        return model_output

    def add_noise(self, original_samples, noise, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample

    def training_target(self, sample, noise, timestep):
        target = noise - sample
        return target

    def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights
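Note: a short usage sketch for the scheduler above; the template choice, step count, and tensor shapes are illustrative:

import torch
from diffsynth.diffusion.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(template="Wan")
scheduler.set_timesteps(num_inference_steps=10, shift=5)   # shift is forwarded to set_timesteps_wan

x0 = torch.randn(1, 16, 8, 8)          # a fake clean latent
noise = torch.randn_like(x0)
timestep = scheduler.timesteps[3]
xt = scheduler.add_noise(x0, noise, timestep)              # (1 - sigma) * x0 + sigma * noise
target = scheduler.training_target(x0, noise, timestep)    # noise - x0

# One Euler step of the reverse process, as used at inference time.
x_prev = scheduler.step(model_output=target, timestep=timestep, sample=xt)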
diffsynth/diffusion/logger.py
ADDED
@@ -0,0 +1,43 @@
import os, torch
from accelerate import Accelerator


class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x: x):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
        self.state_dict_converter = state_dict_converter
        self.num_steps = 0

    def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
        self.num_steps += 1
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.get_state_dict(model)
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)

    def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
        if save_steps is not None and self.num_steps % save_steps != 0:
            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")

    def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.get_state_dict(model)
            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
            state_dict = self.state_dict_converter(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, file_name)
            accelerator.save(state_dict, path, safe_serialization=True)
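Note: a sketch of the checkpoint-saving flow, assuming a training module that implements export_trainable_state_dict (the DiffusionTrainingModule imported by this package is expected to); StubModule below is a hypothetical stand-in and the output path is illustrative:

import torch
from accelerate import Accelerator
from diffsynth.diffusion.logger import ModelLogger

class StubModule(torch.nn.Module):
    # Hypothetical stand-in for a training module.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        if remove_prefix is not None:
            state_dict = {k[len(remove_prefix):] if k.startswith(remove_prefix) else k: v
                          for k, v in state_dict.items()}
        return state_dict

accelerator = Accelerator()
model = accelerator.prepare(StubModule())
logger = ModelLogger(output_path="./checkpoints", remove_prefix_in_ckpt=None)

for step in range(1000):
    logger.on_step_end(accelerator, model, save_steps=500)   # writes step-500 and step-1000 checkpoints
logger.on_training_end(accelerator, model, save_steps=500)   # no-op here, since 1000 % 500 == 0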
diffsynth/diffusion/loss.py
ADDED
@@ -0,0 +1,119 @@
from .base_pipeline import BasePipeline
import torch


def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    noise = torch.randn_like(inputs["input_latents"])
    inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
    loss = loss * pipe.scheduler.training_weight(timestep)
    return loss


def DirectDistillLoss(pipe: BasePipeline, **inputs):
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss


class TrajectoryImitationLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, device):
        import lpips  # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        trajectory = [inputs_shared["latents"].clone()]

        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)

            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]

            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )

            sigma = pipe.scheduler.sigmas[progress_id]
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                latents_ = trajectory_teacher[-1]
            else:
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]

            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)

        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        if not self.initialized:
            self.initialize(pipe.device)
        with torch.no_grad():
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss
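Note: FlowMatchSFTLoss regresses the model output onto the flow-matching target eps - x0 at a noised latent x_t = (1 - sigma) * x_0 + sigma * eps, weighted per timestep. A small numeric check of these identities using the scheduler from this diff (shapes and the template are illustrative):

import torch
from diffsynth.diffusion.flow_match import FlowMatchScheduler

scheduler = FlowMatchScheduler(template="Wan")
scheduler.set_timesteps(num_inference_steps=1000, training=True)

x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
timestep = scheduler.timesteps[200]

xt = scheduler.add_noise(x0, eps, timestep)
target = scheduler.training_target(x0, eps, timestep)          # eps - x0
sigma = scheduler.sigmas[200]
# (1 - s) * x0 + s * eps - s * (eps - x0) == x0
assert torch.allclose(xt - sigma * target, x0, atol=1e-5)
weight = scheduler.training_weight(timestep)                    # per-timestep loss weight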
diffsynth/diffusion/parsers.py
ADDED
@@ -0,0 +1,70 @@
import argparse


def add_dataset_base_config(parser: argparse.ArgumentParser):
    parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
    parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
    parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
    parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
    return parser

def add_image_size_config(parser: argparse.ArgumentParser):
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    return parser

def add_video_size_config(parser: argparse.ArgumentParser):
    parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
    parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
    parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
    return parser

def add_model_config(parser: argparse.ArgumentParser):
    parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
    parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
    parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
    parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
    parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in split training.")
    return parser

def add_training_config(parser: argparse.ArgumentParser):
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
    parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
    return parser

def add_output_config(parser: argparse.ArgumentParser):
    parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
    parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
    parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving intervals. If None, checkpoints will be saved every epoch.")
    return parser

def add_lora_config(parser: argparse.ArgumentParser):
    parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
    parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
    parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
    parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
    parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
    parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
    return parser

def add_gradient_config(parser: argparse.ArgumentParser):
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
    parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
    return parser

def add_general_config(parser: argparse.ArgumentParser):
    parser = add_dataset_base_config(parser)
    parser = add_model_config(parser)
    parser = add_training_config(parser)
    parser = add_output_config(parser)
    parser = add_lora_config(parser)
    parser = add_gradient_config(parser)
    return parser
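Note: the helpers above only attach argument groups to an existing parser; a training script is expected to compose them and call `parse_args()` itself. Below is a minimal sketch of that wiring, assuming the module is importable as `diffsynth.diffusion.parsers` (matching the path in this diff); the CLI values are placeholders.

import argparse
from diffsynth.diffusion.parsers import add_general_config, add_video_size_config

# Compose the argument groups defined above into one launcher parser.
parser = argparse.ArgumentParser(description="Illustrative training launcher")
parser = add_general_config(parser)     # dataset, model, training, output, LoRA, gradient groups
parser = add_video_size_config(parser)  # height / width / max_pixels / num_frames for video data

# --dataset_base_path is required by add_dataset_base_config; everything else has defaults.
args = parser.parse_args([
    "--dataset_base_path", "data/my_dataset",
    "--num_frames", "81",
    "--lora_base_model", "dit",
    "--lora_rank", "32",
])
print(args.learning_rate, args.num_frames, args.lora_target_modules)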
diffsynth/diffusion/runner.py
ADDED
@@ -0,0 +1,129 @@
import os, torch
import wandb  # newly added for experiment logging
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs

    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    global_step = 0  # global step counter for wandb logging

    for epoch_id in range(num_epochs):
        # Show tqdm only on the local main process to avoid duplicated progress bars on multi-GPU runs.
        pbar = tqdm(
            dataloader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch_id}",
        )
        for data in pbar:
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()

            global_step += 1

            # ============= wandb logging (main process only) =============
            if (
                args is not None
                and hasattr(args, "wandb_mode")
                and args.wandb_mode != "disabled"
                and accelerator.is_main_process
            ):
                log_every = getattr(args, "wandb_log_every", 10)
                if global_step % log_every == 0:
                    # Using the loss of the current process is sufficient here.
                    loss_value = loss.detach().float().item()
                    try:
                        lr = scheduler.get_last_lr()[0]
                    except Exception:
                        lr = optimizer.param_groups[0]["lr"]

                    wandb.log(
                        {
                            "train/loss": loss_value,
                            "train/lr": lr,
                            "train/epoch": epoch_id,
                            "train/step": global_step,
                        }
                    )
            # =======================================================

        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    if args is not None:
        num_workers = args.dataset_num_workers

    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, dataloader = accelerator.prepare(model, dataloader)

    for data_id, data in enumerate(tqdm(
        dataloader,
        disable=not accelerator.is_local_main_process,
        desc="Data process",
    )):
        with accelerator.accumulate(model):
            with torch.no_grad():
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(
                    model_logger.output_path,
                    str(accelerator.process_index),
                    f"{data_id}.pth",
                )
                data = model(data)
                torch.save(data, save_path)
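For context, a rough sketch of how `launch_training_task` is meant to be driven: an `Accelerator`, a dataset, a `DiffusionTrainingModule` subclass, and a `ModelLogger` are built elsewhere and passed in, with `args` overriding the keyword defaults. `MyDataset` and `MyTrainingModule` below are hypothetical stand-ins, and the `ModelLogger` constructor arguments are assumed rather than taken from this diff; note that the loop reads `dataset.load_from_cache`, so the dataset must expose that attribute.

from accelerate import Accelerator
from diffsynth.diffusion.runner import launch_training_task
from diffsynth.diffusion.logger import ModelLogger

def main(args):
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
    dataset = MyDataset(args.dataset_base_path, args.dataset_metadata_path)  # hypothetical dataset with a load_from_cache attribute
    model = MyTrainingModule(args)                                           # hypothetical DiffusionTrainingModule subclass returning a loss
    model_logger = ModelLogger(args.output_path)                             # constructor signature assumed
    launch_training_task(
        accelerator, dataset, model, model_logger,
        args=args,  # learning rate, weight decay, workers, save_steps, epochs are read from args
    )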
diffsynth/diffusion/training_module.py
ADDED
@@ -0,0 +1,212 @@
import torch, json
from ..core import ModelConfig, load_state_dict
from ..utils.controlnet import ControlNetInput
from peft import LoraConfig, inject_adapter_in_model


class DiffusionTrainingModule(torch.nn.Module):
    def __init__(self):
        super().__init__()


    def to(self, *args, **kwargs):
        for name, model in self.named_children():
            model.to(*args, **kwargs)
        return self


    def trainable_modules(self):
        trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
        return trainable_modules


    def trainable_param_names(self):
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        return trainable_param_names


    def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
        if lora_alpha is None:
            lora_alpha = lora_rank
        if isinstance(target_modules, list) and len(target_modules) == 1:
            target_modules = target_modules[0]
        lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
        model = inject_adapter_in_model(lora_config, model)
        if upcast_dtype is not None:
            for param in model.parameters():
                if param.requires_grad:
                    param.data = param.to(upcast_dtype)
        return model


    def mapping_lora_state_dict(self, state_dict):
        new_state_dict = {}
        for key, value in state_dict.items():
            if "lora_A.weight" in key or "lora_B.weight" in key:
                new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
                new_state_dict[new_key] = value
            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
                new_state_dict[key] = value
        return new_state_dict


    def export_trainable_state_dict(self, state_dict, remove_prefix=None):
        trainable_param_names = self.trainable_param_names()
        state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
        if remove_prefix is not None:
            state_dict_ = {}
            for name, param in state_dict.items():
                if name.startswith(remove_prefix):
                    name = name[len(remove_prefix):]
                state_dict_[name] = param
            state_dict = state_dict_
        return state_dict


    def transfer_data_to_device(self, data, device, torch_float_dtype=None):
        if data is None:
            return data
        elif isinstance(data, torch.Tensor):
            data = data.to(device)
            if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
                data = data.to(torch_float_dtype)
            return data
        elif isinstance(data, tuple):
            data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
            return data
        elif isinstance(data, list):
            data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
            return data
        elif isinstance(data, dict):
            data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
            return data
        else:
            return data

    def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
        if fp8:
            return {
                "offload_dtype": torch.float8_e4m3fn,
                "offload_device": device,
                "onload_dtype": torch.float8_e4m3fn,
                "onload_device": device,
                "preparing_dtype": torch.float8_e4m3fn,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
            }
        elif offload:
            return {
                "offload_dtype": "disk",
                "offload_device": "disk",
                "onload_dtype": "disk",
                "onload_device": "disk",
                "preparing_dtype": torch.bfloat16,
                "preparing_device": device,
                "computation_dtype": torch.bfloat16,
                "computation_device": device,
                "clear_parameters": True,
            }
        else:
            return {}

    def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
        fp8_models = [] if fp8_models is None else fp8_models.split(",")
        offload_models = [] if offload_models is None else offload_models.split(",")
        model_configs = []
        if model_paths is not None:
            model_paths = json.loads(model_paths)
            for path in model_paths:
                vram_config = self.parse_vram_config(
                    fp8=path in fp8_models,
                    offload=path in offload_models,
                    device=device
                )
                model_configs.append(ModelConfig(path=path, **vram_config))
        if model_id_with_origin_paths is not None:
            model_id_with_origin_paths = model_id_with_origin_paths.split(",")
            for model_id_with_origin_path in model_id_with_origin_paths:
                model_id, origin_file_pattern = model_id_with_origin_path.split(":")
                vram_config = self.parse_vram_config(
                    fp8=model_id_with_origin_path in fp8_models,
                    offload=model_id_with_origin_path in offload_models,
                    device=device
                )
                model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
        return model_configs


    def switch_pipe_to_training_mode(
        self,
        pipe,
        trainable_models=None,
        lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
        preset_lora_path=None, preset_lora_model=None,
        task="sft",
    ):
        # Scheduler
        pipe.scheduler.set_timesteps(1000, training=True)

        # Freeze untrainable models
        pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))

        # Preset LoRA
        if preset_lora_path is not None:
            pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)

        # FP8
        # FP8 relies on a model-specific memory management scheme.
        # It is delegated to the subclass.

        # Add LoRA to the base models
        if lora_base_model is not None and not task.endswith(":data_process"):
            if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
                print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
                return
            model = self.add_lora_to_model(
                getattr(pipe, lora_base_model),
                target_modules=lora_target_modules.split(","),
                lora_rank=lora_rank,
                upcast_dtype=pipe.torch_dtype,
            )
            if lora_checkpoint is not None:
                state_dict = load_state_dict(lora_checkpoint)
                state_dict = self.mapping_lora_state_dict(state_dict)
                load_result = model.load_state_dict(state_dict, strict=False)
                print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
                if len(load_result[1]) > 0:
                    print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
            setattr(pipe, lora_base_model, model)


    def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
        models_require_backward = []
        if trainable_models is not None:
            models_require_backward += trainable_models.split(",")
        if lora_base_model is not None:
            models_require_backward += [lora_base_model]
        if task.endswith(":data_process"):
            _, pipe.units = pipe.split_pipeline_units(models_require_backward)
        elif task.endswith(":train"):
            pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
        return pipe

    def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
        controlnet_keys_map = (
            ("blockwise_controlnet_", "blockwise_controlnet_inputs",),
            ("controlnet_", "controlnet_inputs"),
        )
        controlnet_inputs = {}
        for extra_input in extra_inputs:
            for prefix, name in controlnet_keys_map:
                if extra_input.startswith(prefix):
                    if name not in controlnet_inputs:
                        controlnet_inputs[name] = {}
                    controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
                    break
            else:
                inputs_shared[extra_input] = data[extra_input]
        for name, params in controlnet_inputs.items():
            inputs_shared[name] = [ControlNetInput(**params)]
        return inputs_shared
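As a quick illustration of how the CLI strings from `parsers.py` become `ModelConfig` objects, here is a hedged sketch of `parse_model_configs`; the import path follows this diff's layout and the file names are placeholders.

from diffsynth.diffusion.training_module import DiffusionTrainingModule

module = DiffusionTrainingModule()
model_configs = module.parse_model_configs(
    model_paths='["models/dit.safetensors", "models/vae.safetensors"]',  # JSON list, as --model_paths expects
    model_id_with_origin_paths="Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors",
    fp8_models="models/dit.safetensors",  # this entry receives the FP8 VRAM config
    offload_models=None,
    device="cpu",
)
for config in model_configs:
    print(config)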
diffsynth/models/comp_attn_model.py
ADDED
@@ -0,0 +1,592 @@
import math
from dataclasses import dataclass
from typing import Optional, Sequence

import torch
import torch.nn.functional as F

from ..diffusion.base_pipeline import PipelineUnit


@dataclass
class CompAttnConfig:
    subjects: Sequence[str]
    bboxes: Optional[Sequence] = None
    enable_sci: bool = True
    enable_lam: bool = True
    temperature: float = 0.2
    apply_to_negative: bool = False
    interpolate: bool = False
    state_texts: Optional[Sequence[Sequence[str]]] = None
    state_weights: Optional[Sequence] = None
    state_scale: float = 1.0
    state_template: str = "{subject} is {state}"


def find_subsequence_indices(prompt_ids: torch.Tensor, subject_ids: torch.Tensor, valid_len: int) -> list[int]:
    if subject_ids.numel() == 0 or valid_len <= 0:
        return []
    prompt_slice = prompt_ids[:valid_len].tolist()
    subject_list = subject_ids.tolist()
    span = len(subject_list)
    if span > valid_len:
        return []
    for start in range(valid_len - span + 1):
        if prompt_slice[start:start + span] == subject_list:
            return list(range(start, start + span))
    return []


def build_subject_token_mask(indices_list: list[list[int]], seq_len: int) -> torch.Tensor:
    mask = torch.zeros((len(indices_list), seq_len), dtype=torch.bool)
    for i, indices in enumerate(indices_list):
        if not indices:
            continue
        mask[i, torch.tensor(indices, dtype=torch.long)] = True
    return mask


def compute_saliency(prompt_vecs: torch.Tensor, anchor_vecs: torch.Tensor, tau: float) -> torch.Tensor:
    prompt_norm = prompt_vecs / (prompt_vecs.norm(dim=-1, keepdim=True) + 1e-8)
    anchor_norm = anchor_vecs / (anchor_vecs.norm(dim=-1, keepdim=True) + 1e-8)
    cosine = torch.matmul(prompt_norm, anchor_norm.transpose(0, 1))
    scores = torch.exp(cosine / tau)
    diag = scores.diagonal()
    denom = scores.sum(dim=1).clamp(min=1e-8)
    return diag / denom


def compute_delta(anchor_vecs: torch.Tensor) -> torch.Tensor:
    total = anchor_vecs.sum(dim=0, keepdim=True)
    return anchor_vecs * anchor_vecs.shape[0] - total


_sci_call_count = [0]  # stored in a list so it can be mutated inside the function


def apply_sci(context: torch.Tensor, state: dict, timestep: torch.Tensor) -> torch.Tensor:
    if state is None or not state.get("enable_sci", False):
        return context
    subject_mask = state.get("subject_token_mask")
    delta = state.get("delta")
    saliency = state.get("saliency")
    if subject_mask is None or delta is None or saliency is None:
        return context
    if subject_mask.numel() == 0:
        return context
    t_scale = float(state.get("timestep_scale", 1000.0))
    t_value = float(timestep.reshape(-1)[0].item())
    t_ratio = max(0.0, min(1.0, t_value / t_scale))
    omega = 1.0 - t_ratio
    delta = delta.to(device=context.device, dtype=context.dtype)
    saliency = saliency.to(device=context.device, dtype=context.dtype)
    scale = omega * (1.0 - saliency).unsqueeze(-1)
    delta = delta * scale
    mask = subject_mask.to(device=context.device)
    token_delta = torch.matmul(mask.to(dtype=context.dtype).transpose(0, 1), delta)
    apply_mask = state.get("apply_mask")
    if apply_mask is not None:
        apply_mask = apply_mask.to(device=context.device, dtype=context.dtype).view(-1, 1, 1)
    else:
        apply_mask = 1.0

    # ========== DEBUG: print SCI statistics ==========
    _sci_call_count[0] += 1
    if _sci_call_count[0] % 100 == 1:
        print(f"\n{'='*60}")
        print(f"[SCI (Saliency-Controlled Intervention) #{_sci_call_count[0]}]")
        print(f" timestep: {t_value:.2f}, t_ratio: {t_ratio:.4f}, omega: {omega:.4f}")
        print(f" saliency per subject: {saliency.tolist()}")
        print(f" delta shape: {delta.shape}")
        print(f" delta norm per subject: {delta.norm(dim=-1).tolist()}")
        print(f" token_delta shape: {token_delta.shape}")
        print(f" context modification norm: {(token_delta.unsqueeze(0) * apply_mask).norm().item():.6f}")
        print(f"{'='*60}\n")

    return context + token_delta.unsqueeze(0) * apply_mask


def interpolate_bboxes(bboxes: torch.Tensor, target_frames: int) -> torch.Tensor:
    if bboxes.shape[2] == target_frames:
        return bboxes
    b, m, f, _ = bboxes.shape
    coords = bboxes.reshape(b * m, f, 4).transpose(1, 2)
    coords = F.interpolate(coords, size=target_frames, mode="linear", align_corners=True)
    coords = coords.transpose(1, 2).reshape(b, m, target_frames, 4)
    return coords


def build_layout_mask_from_bboxes(
    bboxes: torch.Tensor,
    grid_size: tuple[int, int, int],
    image_size: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    if bboxes is None:
        return None
    bboxes = bboxes.to(device=device, dtype=dtype)
    b, m, f_layout, _ = bboxes.shape
    f_grid, h_grid, w_grid = grid_size
    height, width = image_size
    layout = torch.zeros((b, m, f_grid, h_grid, w_grid), device=device, dtype=dtype)
    for bi in range(b):
        for mi in range(m):
            for ti in range(f_layout):
                pt = int(ti * f_grid / max(1, f_layout))
                pt = max(0, min(f_grid - 1, pt))
                x0, y0, x1, y1 = bboxes[bi, mi, ti]
                x0 = float(x0)
                y0 = float(y0)
                x1 = float(x1)
                y1 = float(y1)
                if x1 <= x0 or y1 <= y0:
                    continue
                px0 = int(math.floor(x0 / max(1.0, width) * w_grid))
                px1 = int(math.ceil(x1 / max(1.0, width) * w_grid))
                py0 = int(math.floor(y0 / max(1.0, height) * h_grid))
                py1 = int(math.ceil(y1 / max(1.0, height) * h_grid))
                px0 = max(0, min(w_grid, px0))
                px1 = max(0, min(w_grid, px1))
                py0 = max(0, min(h_grid, py0))
                py1 = max(0, min(h_grid, py1))
                if px1 <= px0 or py1 <= py0:
                    continue
                layout[bi, mi, pt, py0:py1, px0:px1] = 1.0
    return layout.flatten(2)


_lam_attention_call_count = [0]  # stored in a list so it can be mutated inside the function


def lam_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    state: dict,
) -> Optional[torch.Tensor]:
    subject_mask = state.get("subject_token_mask_lam")
    if subject_mask is None:
        subject_mask = state.get("subject_token_mask")
    layout_mask = state.get("layout_mask")
    state_token_mask = state.get("state_token_mask")
    state_token_weights = state.get("state_token_weights")
    state_scale = float(state.get("state_scale", 1.0))
    grid_shape = state.get("grid_shape")
    enable_lam = bool(state.get("enable_lam", False))
    enable_state = state_token_mask is not None and state_token_weights is not None and grid_shape is not None
    if not enable_lam and not enable_state:
        return None
    b, q_len, dim = q.shape
    _, k_len, _ = k.shape
    if enable_lam:
        if subject_mask is None or layout_mask is None:
            return None
        if subject_mask.numel() == 0 or layout_mask.numel() == 0:
            return None
        if layout_mask.shape[-1] != q_len:
            return None
        if subject_mask.shape[-1] != k_len:
            return None
    if enable_state:
        if state_token_mask.shape[-1] != k_len:
            return None
    head_dim = dim // num_heads
    qh = q.view(b, q_len, num_heads, head_dim).transpose(1, 2)
    kh = k.view(b, k_len, num_heads, head_dim).transpose(1, 2)
    vh = v.view(b, k_len, num_heads, head_dim).transpose(1, 2)
    attn_scores = torch.matmul(qh.float(), kh.float().transpose(-2, -1)) / math.sqrt(head_dim)

    # ========== DEBUG: print attention map statistics ==========
    _lam_attention_call_count[0] += 1
    call_id = _lam_attention_call_count[0]
    # Print only once every 100 calls to avoid excessive output.
    if call_id % 100 == 1:
        print(f"\n{'='*60}")
        print(f"[LAM Attention #{call_id}]")
        print(f" Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
        print(f" num_heads: {num_heads}, head_dim: {head_dim}")
        print(f" attn_scores shape: {attn_scores.shape}")
        print(f" attn_scores stats: min={attn_scores.min().item():.4f}, max={attn_scores.max().item():.4f}, mean={attn_scores.mean().item():.4f}")
        if enable_lam and layout_mask is not None:
            print(f" layout_mask shape: {layout_mask.shape}")
            print(f" layout_mask sum per subject: {layout_mask.sum(dim=-1)}")
        if subject_mask is not None:
            print(f" subject_token_mask shape: {subject_mask.shape}")
            print(f" subject_token_mask active tokens per subject: {subject_mask.sum(dim=-1).tolist()}")
        if grid_shape is not None:
            print(f" grid_shape (f, h, w): {grid_shape}")
        print(f"{'='*60}")
    bias = torch.zeros_like(attn_scores)
    if enable_lam:
        attn_max = attn_scores.max(dim=-1, keepdim=True).values
        attn_min = attn_scores.min(dim=-1, keepdim=True).values
        g_plus = attn_max - attn_scores
        g_minus = attn_min - attn_scores
        subject_mask = subject_mask.to(device=attn_scores.device)
        layout_mask = layout_mask.to(device=attn_scores.device, dtype=attn_scores.dtype)
        apply_mask = state.get("apply_mask")
        if apply_mask is not None:
            layout_mask = layout_mask * apply_mask.to(device=layout_mask.device, dtype=layout_mask.dtype).view(-1, 1, 1)
        subject_any = subject_mask.any(dim=0)
        for k_idx in range(subject_mask.shape[0]):
            mask_k = subject_mask[k_idx]
            if not mask_k.any():
                continue
            mask_other = subject_any & (~mask_k)
            mask_k = mask_k.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
            mask_other = mask_other.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
            g_k = g_plus * mask_k + g_minus * mask_other
            attn_k = attn_scores[..., subject_mask[k_idx]].mean(dim=-1).mean(dim=1)
            adapt_mask = attn_k >= attn_k.mean(dim=-1, keepdim=True)
            layout_k = layout_mask[:, k_idx]
            adapt_f = adapt_mask.to(layout_k.dtype)
            inter = (adapt_f * layout_k).sum(dim=-1)
            union = (adapt_f + layout_k - adapt_f * layout_k).sum(dim=-1)
            iou = inter / union.clamp(min=1e-6)
            strength = (1.0 - iou).view(b, 1, 1, 1)
            bias = bias + g_k * strength * layout_k.view(b, 1, q_len, 1)
    if enable_state:
        f, h, w = grid_shape
        if f * h * w != q_len:
            return None
        state_token_mask = state_token_mask.to(device=attn_scores.device)
        state_indices = torch.nonzero(state_token_mask, as_tuple=False).flatten()
        if state_indices.numel() == 0:
            return None
        weights = state_token_weights.to(device=attn_scores.device, dtype=attn_scores.dtype)
        if weights.shape[1] != f:
            return None
        time_index = torch.arange(q_len, device=attn_scores.device) // (h * w)
        weights_q = weights[:, time_index, :]
        if weights_q.shape[-1] != state_indices.numel():
            return None
        state_bias = torch.zeros((b, 1, q_len, k_len), device=attn_scores.device, dtype=attn_scores.dtype)
        state_bias[:, :, :, state_indices] = weights_q.unsqueeze(1) * state_scale
        bias = bias + state_bias
    attn_probs = torch.softmax(attn_scores + bias, dim=-1).to(vh.dtype)

    # ========== DEBUG: print attention probabilities and bias statistics ==========
    if _lam_attention_call_count[0] % 100 == 1:
        print(f"\n[LAM Attention #{_lam_attention_call_count[0]} - After Bias]")
        print(f" bias shape: {bias.shape}")
        print(f" bias stats: min={bias.min().item():.4f}, max={bias.max().item():.4f}, mean={bias.mean().item():.4f}")
        print(f" bias non-zero ratio: {(bias != 0).float().mean().item():.4f}")
        print(f" attn_probs shape: {attn_probs.shape}")
        print(f" attn_probs stats: min={attn_probs.min().item():.6f}, max={attn_probs.max().item():.6f}")
        # Print the average attention weight on each subject's tokens.
        if subject_mask is not None:
            for subj_idx in range(subject_mask.shape[0]):
                mask_k = subject_mask[subj_idx]
                if mask_k.any():
                    # Average attention from all queries to this subject's tokens.
                    subj_attn = attn_probs[:, :, :, mask_k.to(attn_probs.device)].mean()
                    print(f" Subject {subj_idx} avg attention weight: {subj_attn.item():.6f}")
        print(f"{'='*60}\n")

    out = torch.matmul(attn_probs, vh)
    out = out.transpose(1, 2).reshape(b, q_len, dim)
    return out


class CompAttnUnit(PipelineUnit):
    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "context": "context"},
            input_params_nega={"prompt": "negative_prompt", "context": "context"},
            output_params=("comp_attn_state",),
            onload_model_names=("text_encoder",),
        )

    def _clean_text(self, pipe, text: str) -> str:
        if getattr(pipe.tokenizer, "clean", None):
            return pipe.tokenizer._clean(text)
        return text

    def _tokenize_subject(self, pipe, text: str) -> torch.Tensor:
        text = self._clean_text(pipe, text)
        tokens = pipe.tokenizer.tokenizer(text, add_special_tokens=False, return_tensors="pt")
        return tokens["input_ids"][0]

    def _normalize_bboxes(self, bboxes: Sequence) -> torch.Tensor:
        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
        if bboxes.dim() == 2 and bboxes.shape[-1] == 4:
            bboxes = bboxes.unsqueeze(0).unsqueeze(0)
        elif bboxes.dim() == 3 and bboxes.shape[-1] == 4:
            bboxes = bboxes.unsqueeze(0)
        elif bboxes.dim() != 4 or bboxes.shape[-1] != 4:
            raise ValueError(f"comp_attn_bboxes must be (..., 4), got shape {tuple(bboxes.shape)}")
        return bboxes

    def process(self, pipe, prompt, context) -> dict:
        config: Optional[CompAttnConfig] = getattr(pipe, "_comp_attn_config", None)
        if context is None or prompt is None or config is None:
            return {}
        if not config.subjects:
            return {}
        negative_prompt = getattr(pipe, "_comp_attn_last_negative_prompt", None)
        if (not config.apply_to_negative) and negative_prompt and prompt == negative_prompt:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        prompt_ids = ids[0]
        valid_len = int(mask[0].sum().item())
        indices_list = []
        valid_subjects = []
        for idx, subject in enumerate(config.subjects):
            subject_ids = self._tokenize_subject(pipe, subject)
            indices = find_subsequence_indices(prompt_ids, subject_ids, valid_len)
            if not indices:
                print(f"Comp-Attn: subject tokens not found in prompt: {subject}")
                continue
            indices_list.append(indices)
            valid_subjects.append(idx)
        if not indices_list:
            return {}
        subject_token_mask = build_subject_token_mask(indices_list, prompt_ids.shape[0]).to(device=context.device)
        mask_float = subject_token_mask.to(dtype=context.dtype)
        denom = mask_float.sum(dim=1, keepdim=True).clamp(min=1)
        prompt_vecs = (mask_float @ context[0]) / denom
        anchor_vecs = []
        for idx in valid_subjects:
            subject = config.subjects[idx]
            sub_ids, sub_mask = pipe.tokenizer(subject, return_mask=True, add_special_tokens=True)
            sub_ids = sub_ids.to(pipe.device)
            sub_mask = sub_mask.to(pipe.device)
            emb = pipe.text_encoder(sub_ids, sub_mask)
            pooled = (emb * sub_mask.unsqueeze(-1)).sum(dim=1) / sub_mask.sum(dim=1, keepdim=True).clamp(min=1)
            anchor_vecs.append(pooled)
        anchor_vecs = torch.cat(anchor_vecs, dim=0)
        saliency = compute_saliency(prompt_vecs.float(), anchor_vecs.float(), float(config.temperature)).to(prompt_vecs.dtype)
        delta = compute_delta(anchor_vecs.to(prompt_vecs.dtype))
        bboxes = None
        state_vectors = None
        state_weights = None
        state_len = 0
        if config.bboxes is not None:
            bboxes = self._normalize_bboxes(config.bboxes)
            if bboxes.shape[1] >= len(config.subjects):
                bboxes = bboxes[:, valid_subjects]
            if bboxes.shape[1] != len(valid_subjects):
                print("Comp-Attn: bboxes subject count mismatch, disable LAM")
                bboxes = None
        if bboxes is not None and config.interpolate and getattr(pipe, "_comp_attn_num_frames", None) is not None:
            bboxes = interpolate_bboxes(bboxes, int(pipe._comp_attn_num_frames))
        if config.state_texts is not None and config.state_weights is not None:
            state_texts = config.state_texts
            if len(valid_subjects) != len(config.subjects):
                subject_names = [config.subjects[i] for i in valid_subjects]
                state_texts = [state_texts[i] for i in valid_subjects]
            else:
                subject_names = list(config.subjects)
            if len(state_texts) != len(subject_names):
                raise ValueError("state_texts must align with subjects")
            state_count = len(state_texts[0])
            for row in state_texts:
                if len(row) != state_count:
                    raise ValueError("state_texts must have the same number of states per subject")
            phrases = []
            for subject, states in zip(subject_names, state_texts):
                for state in states:
                    phrases.append(config.state_template.format(subject=subject, state=state))
            ids, mask = pipe.tokenizer(phrases, return_mask=True, add_special_tokens=True)
            ids = ids.to(pipe.device)
            mask = mask.to(pipe.device)
            emb = pipe.text_encoder(ids, mask)
            pooled = (emb * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
            state_vectors = pooled.to(dtype=prompt_vecs.dtype, device="cpu")
            state_len = state_vectors.shape[0]
            weights = torch.as_tensor(config.state_weights, dtype=torch.float32)
            if weights.dim() == 3:
                weights = weights.unsqueeze(0)
            if weights.dim() != 4:
                raise ValueError("state_weights must be (M,F,S) or (B,M,F,S)")
            if weights.shape[1] >= len(config.subjects) and len(valid_subjects) != len(config.subjects):
                weights = weights[:, valid_subjects]
            if weights.shape[1] != len(subject_names) or weights.shape[3] != state_count:
                raise ValueError("state_weights shape does not match state_texts")
            weights = weights[:, :len(subject_names)]
            weights = weights.permute(0, 2, 1, 3).contiguous()
            weights = weights.reshape(weights.shape[0], weights.shape[1], weights.shape[2] * weights.shape[3])
            state_weights = weights.to(device="cpu")
        state = {
            "enable_sci": bool(config.enable_sci),
            "enable_lam": bool(config.enable_lam) and bboxes is not None,
            "subject_token_mask": subject_token_mask,
            "saliency": saliency,
            "delta": delta,
            "layout_bboxes": bboxes,
            "state_vectors": state_vectors,
            "state_weights": state_weights,
            "state_scale": float(config.state_scale),
            "prompt_len": int(prompt_ids.shape[0]),
            "state_len": int(state_len),
            "timestep_scale": 1000.0,
            "apply_to_negative": bool(config.apply_to_negative),
        }
        if negative_prompt and prompt == negative_prompt:
            pipe._comp_attn_state_neg = state
        else:
            pipe._comp_attn_state_pos = state
        return {"comp_attn_state": state}


class CompAttnMergeUnit(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("cfg_merge",), output_params=("comp_attn_state",))

    def process(self, pipe, cfg_merge) -> dict:
        if not cfg_merge:
            return {}
        state_pos = getattr(pipe, "_comp_attn_state_pos", None)
        state_neg = getattr(pipe, "_comp_attn_state_neg", None)
        merged = state_pos or state_neg
        if merged is None:
            return {}
        merged = dict(merged)
        apply_to_negative = bool(merged.get("apply_to_negative", False))
        merged["apply_mask"] = torch.tensor([1.0, 1.0 if apply_to_negative else 0.0])
        return {"comp_attn_state": merged}


def patch_cross_attention(pipe) -> None:
    for block in pipe.dit.blocks:
        cross_attn = block.cross_attn
        if getattr(cross_attn, "_comp_attn_patched", False):
            continue
        orig_forward = cross_attn.forward

        def forward_with_lam(self, x, y, _orig=orig_forward, _pipe=pipe):
            state = getattr(_pipe, "_comp_attn_runtime_state", None)
            enable_lam = bool(state.get("enable_lam", False)) if state else False
            enable_state = bool(state.get("state_token_weights") is not None) if state else False
            if state is None or (not enable_lam and not enable_state):
                return _orig(x, y)
            if self.has_image_input:
                img = y[:, :257]
                ctx = y[:, 257:]
            else:
                ctx = y
            q = self.norm_q(self.q(x))
            k = self.norm_k(self.k(ctx))
            v = self.v(ctx)
            lam_out = lam_attention(q, k, v, self.num_heads, state)
            if lam_out is None:
                out = self.attn(q, k, v)
            else:
                out = lam_out
            if self.has_image_input:
                k_img = self.norm_k_img(self.k_img(img))
                v_img = self.v_img(img)
                img_out = self.attn(q, k_img, v_img)
                out = out + img_out
            return self.o(out)

        cross_attn.forward = forward_with_lam.__get__(cross_attn, cross_attn.__class__)
        cross_attn._comp_attn_patched = True


def get_grid_from_latents(latents: torch.Tensor, patch_size: tuple[int, int, int]) -> tuple[int, int, int]:
    f = latents.shape[2] // patch_size[0]
    h = latents.shape[3] // patch_size[1]
    w = latents.shape[4] // patch_size[2]
    return f, h, w


def wrap_model_fn(pipe) -> None:
    if getattr(pipe, "_comp_attn_model_fn_patched", False):
        return
    orig_model_fn = pipe.model_fn

    def model_fn_wrapper(*args, **kwargs):
        comp_attn_state = kwargs.pop("comp_attn_state", None)
        height = kwargs.get("height")
        width = kwargs.get("width")
        num_frames = kwargs.get("num_frames")
        if num_frames is not None:
            pipe._comp_attn_num_frames = num_frames
        if comp_attn_state is None:
            return orig_model_fn(*args, **kwargs)
        latents = kwargs.get("latents")
        timestep = kwargs.get("timestep")
        context = kwargs.get("context")
        clip_feature = kwargs.get("clip_feature")
        reference_latents = kwargs.get("reference_latents")
        state_vectors = comp_attn_state.get("state_vectors")
        state_weights = comp_attn_state.get("state_weights")
        state_len = int(comp_attn_state.get("state_len", 0))
        prompt_len = int(comp_attn_state.get("prompt_len", context.shape[1] if context is not None else 0))
        if context is not None and timestep is not None:
            context = apply_sci(context, comp_attn_state, timestep)
            if state_vectors is not None and state_len > 0:
                state_vec = state_vectors.to(device=context.device, dtype=context.dtype)
                if state_vec.dim() == 2:
                    state_vec = state_vec.unsqueeze(0)
                if state_vec.shape[0] != context.shape[0]:
                    state_vec = state_vec.repeat(context.shape[0], 1, 1)
                context = torch.cat([context, state_vec], dim=1)
            kwargs["context"] = context
        subject_mask = comp_attn_state.get("subject_token_mask")
        if subject_mask is not None:
            clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0
            pad_clip = torch.zeros((subject_mask.shape[0], clip_len), dtype=torch.bool)
            pad_state = torch.zeros((subject_mask.shape[0], state_len), dtype=torch.bool)
            comp_attn_state["subject_token_mask_lam"] = torch.cat([pad_clip, subject_mask.cpu(), pad_state], dim=1)
        if state_vectors is not None and state_len > 0:
            clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0
            pad_prompt = torch.zeros((state_len, clip_len + prompt_len), dtype=torch.bool)
            ones_state = torch.ones((state_len, state_len), dtype=torch.bool)
            state_token_mask = torch.cat([pad_prompt, ones_state], dim=1).any(dim=0)
            comp_attn_state["state_token_mask"] = state_token_mask
        if latents is not None and height is not None and width is not None:
            f, h, w = get_grid_from_latents(latents, pipe.dit.patch_size)
            if comp_attn_state.get("enable_lam", False):
                q_len = f * h * w
                if reference_latents is not None:
                    q_len = (f + 1) * h * w
                layout_mask = comp_attn_state.get("layout_mask")
                layout_shape = comp_attn_state.get("layout_shape")
                if layout_mask is None or layout_shape != (latents.shape[0], q_len):
                    layout_mask = build_layout_mask_from_bboxes(
                        comp_attn_state.get("layout_bboxes"),
                        (f, h, w),
                        (int(height), int(width)),
                        device=latents.device,
                        dtype=latents.dtype,
                    )
                    if reference_latents is not None:
                        pad = torch.zeros((layout_mask.shape[0], layout_mask.shape[1], h * w), device=latents.device, dtype=latents.dtype)
                        layout_mask = torch.cat([pad, layout_mask], dim=-1)
                    if layout_mask.shape[0] != latents.shape[0]:
                        layout_mask = layout_mask.repeat(latents.shape[0], 1, 1)
                    comp_attn_state["layout_mask"] = layout_mask
                    comp_attn_state["layout_shape"] = (latents.shape[0], q_len)
            if state_weights is not None:
                weights = state_weights.to(device=latents.device, dtype=latents.dtype)
                if weights.shape[0] != latents.shape[0]:
                    weights = weights.repeat(latents.shape[0], 1, 1)
                if weights.shape[1] != f:
                    weights = weights.transpose(1, 2)
                    weights = F.interpolate(weights, size=f, mode="linear", align_corners=True)
                    weights = weights.transpose(1, 2)
                if reference_latents is not None:
                    pad = torch.zeros((weights.shape[0], 1, weights.shape[2]), device=weights.device, dtype=weights.dtype)
                    weights = torch.cat([pad, weights], dim=1)
                    f = f + 1
                comp_attn_state["state_token_weights"] = weights
                comp_attn_state["grid_shape"] = (f, h, w)
        if (
            latents is not None
            and latents.shape[0] == 2
            and not comp_attn_state.get("apply_to_negative", False)
            and "apply_mask" not in comp_attn_state
        ):
            comp_attn_state["apply_mask"] = torch.tensor([1.0, 0.0], device=latents.device, dtype=latents.dtype)
        pipe._comp_attn_runtime_state = comp_attn_state
        try:
            return orig_model_fn(*args, **kwargs)
        finally:
            pipe._comp_attn_runtime_state = None

    pipe.model_fn = model_fn_wrapper
    pipe._comp_attn_model_fn_patched = True
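The units and patches above all read state from attributes on the pipeline object, so attaching Comp-Attn amounts to setting `_comp_attn_config` and installing the two hooks. A sketch under stated assumptions: `pipe` is assumed to be a Wan-style pipeline exposing `.dit`, `.tokenizer`, `.text_encoder`, `.model_fn`, and a `units` list; the subjects, boxes, and the way units are registered are illustrative, not taken from this diff.

from diffsynth.models.comp_attn_model import (
    CompAttnConfig, CompAttnUnit, CompAttnMergeUnit,
    patch_cross_attention, wrap_model_fn,
)

pipe._comp_attn_config = CompAttnConfig(
    subjects=["a red fox", "a golden egg"],
    bboxes=[[[0, 80, 300, 320]], [[340, 80, 620, 320]]],  # (M, F, 4): one pixel-space box per subject per keyframe
    enable_sci=True,   # saliency-controlled intervention on the text context
    enable_lam=True,   # layout-aware attention modulation driven by the bboxes
)
pipe._comp_attn_last_negative_prompt = None

pipe.units = list(pipe.units) + [CompAttnUnit(), CompAttnMergeUnit()]  # unit registration assumed to work this way
patch_cross_attention(pipe)  # wraps each DiT block's cross-attention with lam_attention
wrap_model_fn(pipe)          # intercepts model_fn to apply SCI and build layout masks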
diffsynth/models/dinov3_image_encoder.py
ADDED
@@ -0,0 +1,94 @@
from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
import torch


class DINOv3ImageEncoder(DINOv3ViTModel):
    def __init__(self):
        config = DINOv3ViTConfig(
            architectures = [
                "DINOv3ViTModel"
            ],
            attention_dropout = 0.0,
            drop_path_rate = 0.0,
            dtype = "float32",
            hidden_act = "silu",
            hidden_size = 4096,
            image_size = 224,
            initializer_range = 0.02,
            intermediate_size = 8192,
            key_bias = False,
            layer_norm_eps = 1e-05,
            layerscale_value = 1.0,
            mlp_bias = True,
            model_type = "dinov3_vit",
            num_attention_heads = 32,
            num_channels = 3,
            num_hidden_layers = 40,
            num_register_tokens = 4,
            patch_size = 16,
            pos_embed_jitter = None,
            pos_embed_rescale = 2.0,
            pos_embed_shift = None,
            proj_bias = True,
            query_bias = False,
            rope_theta = 100.0,
            transformers_version = "4.56.1",
            use_gated_mlp = True,
            value_bias = False
        )
        super().__init__(config)
        self.processor = DINOv3ViTImageProcessorFast(
            crop_size = None,
            data_format = "channels_first",
            default_to_square = True,
            device = None,
            disable_grouping = None,
            do_center_crop = None,
            do_convert_rgb = None,
            do_normalize = True,
            do_rescale = True,
            do_resize = True,
            image_mean = [
                0.485,
                0.456,
                0.406
            ],
            image_processor_type = "DINOv3ViTImageProcessorFast",
            image_std = [
                0.229,
                0.224,
                0.225
            ],
            input_data_format = None,
            resample = 2,
            rescale_factor = 0.00392156862745098,
            return_tensors = None,
            size = {
                "height": 224,
                "width": 224
            }
        )

    def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
        bool_masked_pos = None
        head_mask = None

        pixel_values = pixel_values.to(torch_dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            hidden_states = layer_module(
                hidden_states,
                attention_mask=layer_head_mask,
                position_embeddings=position_embeddings,
            )

        sequence_output = self.norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]

        return pooled_output
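A minimal call-signature sketch for the encoder above: the forward pass pre-processes the image with the fast processor, runs the ViT trunk manually, and returns only the pooled CLS embedding (hidden size 4096). The module here is randomly initialized and fairly large; loading real DINOv3 weights is assumed to go through the repository's model loading machinery, and the input image is a placeholder.

import torch
from PIL import Image
from diffsynth.models.dinov3_image_encoder import DINOv3ImageEncoder

encoder = DINOv3ImageEncoder().to(device="cuda", dtype=torch.bfloat16).eval()
image = Image.new("RGB", (640, 320), color=(128, 128, 128))  # placeholder input image

with torch.no_grad():
    cls_embedding = encoder(image, torch_dtype=torch.bfloat16, device="cuda")
print(cls_embedding.shape)  # torch.Size([1, 4096])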
diffsynth/models/flux2_dit.py
ADDED
|
@@ -0,0 +1,1057 @@
| 1 |
+
import inspect
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch, math
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
from ..core.attention import attention_forward
|
| 9 |
+
from ..core.gradient import gradient_checkpoint_forward
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_timestep_embedding(
|
| 13 |
+
timesteps: torch.Tensor,
|
| 14 |
+
embedding_dim: int,
|
| 15 |
+
flip_sin_to_cos: bool = False,
|
| 16 |
+
downscale_freq_shift: float = 1,
|
| 17 |
+
scale: float = 1,
|
| 18 |
+
max_period: int = 10000,
|
| 19 |
+
) -> torch.Tensor:
|
| 20 |
+
"""
|
| 21 |
+
Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
timesteps (torch.Tensor):
|
| 25 |
+
a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
| 26 |
+
embedding_dim (int):
|
| 27 |
+
the dimension of the output.
|
| 28 |
+
flip_sin_to_cos (bool):
|
| 29 |
+
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
|
| 30 |
+
downscale_freq_shift (float):
|
| 31 |
+
Controls the delta between frequencies between dimensions
|
| 32 |
+
scale (float):
|
| 33 |
+
Scaling factor applied to the embeddings.
|
| 34 |
+
max_period (int):
|
| 35 |
+
Controls the maximum frequency of the embeddings
|
| 36 |
+
Returns:
|
| 37 |
+
torch.Tensor: an [N x dim] Tensor of positional embeddings.
|
| 38 |
+
"""
|
| 39 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 40 |
+
|
| 41 |
+
half_dim = embedding_dim // 2
|
| 42 |
+
exponent = -math.log(max_period) * torch.arange(
|
| 43 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
| 44 |
+
)
|
| 45 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 46 |
+
|
| 47 |
+
emb = torch.exp(exponent)
|
| 48 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
| 49 |
+
|
| 50 |
+
# scale embeddings
|
| 51 |
+
emb = scale * emb
|
| 52 |
+
|
| 53 |
+
# concat sine and cosine embeddings
|
| 54 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 55 |
+
|
| 56 |
+
# flip sine and cosine embeddings
|
| 57 |
+
if flip_sin_to_cos:
|
| 58 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
| 59 |
+
|
| 60 |
+
# zero pad
|
| 61 |
+
if embedding_dim % 2 == 1:
|
| 62 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 63 |
+
return emb
|
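# Editorial sketch (not part of the uploaded file): a quick shape check for the helper above.
# With embedding_dim=256 the function builds 128 frequencies, then concatenates the sin and
# cos halves, so a batch of two timesteps yields a [2, 256] embedding.
_t = torch.tensor([0.0, 500.0])
_emb = get_timestep_embedding(_t, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0)
assert _emb.shape == (2, 256)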
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TimestepEmbedding(nn.Module):
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
in_channels: int,
|
| 70 |
+
time_embed_dim: int,
|
| 71 |
+
act_fn: str = "silu",
|
| 72 |
+
out_dim: int = None,
|
| 73 |
+
post_act_fn: Optional[str] = None,
|
| 74 |
+
cond_proj_dim=None,
|
| 75 |
+
sample_proj_bias=True,
|
| 76 |
+
):
|
| 77 |
+
super().__init__()
|
| 78 |
+
|
| 79 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
|
| 80 |
+
|
| 81 |
+
if cond_proj_dim is not None:
|
| 82 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
| 83 |
+
else:
|
| 84 |
+
self.cond_proj = None
|
| 85 |
+
|
| 86 |
+
self.act = torch.nn.SiLU()
|
| 87 |
+
|
| 88 |
+
if out_dim is not None:
|
| 89 |
+
time_embed_dim_out = out_dim
|
| 90 |
+
else:
|
| 91 |
+
time_embed_dim_out = time_embed_dim
|
| 92 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
|
| 93 |
+
|
| 94 |
+
if post_act_fn is None:
|
| 95 |
+
self.post_act = None
|
| 96 |
+
|
| 97 |
+
def forward(self, sample, condition=None):
|
| 98 |
+
if condition is not None:
|
| 99 |
+
sample = sample + self.cond_proj(condition)
|
| 100 |
+
sample = self.linear_1(sample)
|
| 101 |
+
|
| 102 |
+
if self.act is not None:
|
| 103 |
+
sample = self.act(sample)
|
| 104 |
+
|
| 105 |
+
sample = self.linear_2(sample)
|
| 106 |
+
|
| 107 |
+
if self.post_act is not None:
|
| 108 |
+
sample = self.post_act(sample)
|
| 109 |
+
return sample
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Timesteps(nn.Module):
|
| 113 |
+
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.num_channels = num_channels
|
| 116 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
| 117 |
+
self.downscale_freq_shift = downscale_freq_shift
|
| 118 |
+
self.scale = scale
|
| 119 |
+
|
| 120 |
+
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
|
| 121 |
+
t_emb = get_timestep_embedding(
|
| 122 |
+
timesteps,
|
| 123 |
+
self.num_channels,
|
| 124 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
| 125 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
| 126 |
+
scale=self.scale,
|
| 127 |
+
)
|
| 128 |
+
return t_emb
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class AdaLayerNormContinuous(nn.Module):
|
| 132 |
+
r"""
|
| 133 |
+
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
embedding_dim (`int`): Embedding dimension to use during projection.
|
| 137 |
+
conditioning_embedding_dim (`int`): Dimension of the input condition.
|
| 138 |
+
elementwise_affine (`bool`, defaults to `True`):
|
| 139 |
+
Boolean flag to denote if affine transformation should be applied.
|
| 140 |
+
eps (`float`, defaults to 1e-5): Epsilon factor.
|
| 141 |
+
bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
|
| 142 |
+
norm_type (`str`, defaults to `"layer_norm"`):
|
| 143 |
+
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
|
| 144 |
+
"""
|
| 145 |
+
|
| 146 |
+
def __init__(
|
| 147 |
+
self,
|
| 148 |
+
embedding_dim: int,
|
| 149 |
+
conditioning_embedding_dim: int,
|
| 150 |
+
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
| 151 |
+
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
| 152 |
+
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
| 153 |
+
# However, this is how it was implemented in the original code, and it's rather likely you should
|
| 154 |
+
# set `elementwise_affine` to False.
|
| 155 |
+
elementwise_affine=True,
|
| 156 |
+
eps=1e-5,
|
| 157 |
+
bias=True,
|
| 158 |
+
norm_type="layer_norm",
|
| 159 |
+
):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.silu = nn.SiLU()
|
| 162 |
+
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
| 163 |
+
if norm_type == "layer_norm":
|
| 164 |
+
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
| 165 |
+
|
| 166 |
+
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
| 167 |
+
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
| 168 |
+
emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
|
| 169 |
+
scale, shift = torch.chunk(emb, 2, dim=1)
|
| 170 |
+
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
| 171 |
+
return x
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def get_1d_rotary_pos_embed(
|
| 175 |
+
dim: int,
|
| 176 |
+
pos: Union[np.ndarray, int],
|
| 177 |
+
theta: float = 10000.0,
|
| 178 |
+
use_real=False,
|
| 179 |
+
linear_factor=1.0,
|
| 180 |
+
ntk_factor=1.0,
|
| 181 |
+
repeat_interleave_real=True,
|
| 182 |
+
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
|
| 183 |
+
):
|
| 184 |
+
"""
|
| 185 |
+
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
| 186 |
+
|
| 187 |
+
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
|
| 188 |
+
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
|
| 189 |
+
data type.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
dim (`int`): Dimension of the frequency tensor.
|
| 193 |
+
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
|
| 194 |
+
theta (`float`, *optional*, defaults to 10000.0):
|
| 195 |
+
Scaling factor for frequency computation. Defaults to 10000.0.
|
| 196 |
+
use_real (`bool`, *optional*):
|
| 197 |
+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
| 198 |
+
linear_factor (`float`, *optional*, defaults to 1.0):
|
| 199 |
+
Scaling factor for the context extrapolation. Defaults to 1.0.
|
| 200 |
+
ntk_factor (`float`, *optional*, defaults to 1.0):
|
| 201 |
+
Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
|
| 202 |
+
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
| 203 |
+
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
| 204 |
+
Otherwise, they are concatenated with themselves.
|
| 205 |
+
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
| 206 |
+
the dtype of the frequency tensor.
|
| 207 |
+
Returns:
|
| 208 |
+
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
| 209 |
+
"""
|
| 210 |
+
assert dim % 2 == 0
|
| 211 |
+
|
| 212 |
+
if isinstance(pos, int):
|
| 213 |
+
pos = torch.arange(pos)
|
| 214 |
+
if isinstance(pos, np.ndarray):
|
| 215 |
+
pos = torch.from_numpy(pos) # type: ignore # [S]
|
| 216 |
+
|
| 217 |
+
theta = theta * ntk_factor
|
| 218 |
+
freqs = (
|
| 219 |
+
1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
|
| 220 |
+
) # [D/2]
|
| 221 |
+
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
| 222 |
+
is_npu = freqs.device.type == "npu"
|
| 223 |
+
if is_npu:
|
| 224 |
+
freqs = freqs.float()
|
| 225 |
+
if use_real and repeat_interleave_real:
|
| 226 |
+
# flux, hunyuan-dit, cogvideox
|
| 227 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
| 228 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
| 229 |
+
return freqs_cos, freqs_sin
|
| 230 |
+
elif use_real:
|
| 231 |
+
# stable audio, allegro
|
| 232 |
+
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
| 233 |
+
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
| 234 |
+
return freqs_cos, freqs_sin
|
| 235 |
+
else:
|
| 236 |
+
# lumina
|
| 237 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
| 238 |
+
return freqs_cis
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def apply_rotary_emb(
|
| 242 |
+
x: torch.Tensor,
|
| 243 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
|
| 244 |
+
use_real: bool = True,
|
| 245 |
+
use_real_unbind_dim: int = -1,
|
| 246 |
+
sequence_dim: int = 2,
|
| 247 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 248 |
+
"""
|
| 249 |
+
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
| 250 |
+
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
| 251 |
+
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
| 252 |
+
tensors contain rotary embeddings and are returned as real tensors.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
x (`torch.Tensor`):
|
| 256 |
+
Query or key tensor to apply rotary embeddings to. [B, H, S, D]
|
| 257 |
+
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
| 261 |
+
"""
|
| 262 |
+
if use_real:
|
| 263 |
+
cos, sin = freqs_cis # [S, D]
|
| 264 |
+
if sequence_dim == 2:
|
| 265 |
+
cos = cos[None, None, :, :]
|
| 266 |
+
sin = sin[None, None, :, :]
|
| 267 |
+
elif sequence_dim == 1:
|
| 268 |
+
cos = cos[None, :, None, :]
|
| 269 |
+
sin = sin[None, :, None, :]
|
| 270 |
+
else:
|
| 271 |
+
raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
|
| 272 |
+
|
| 273 |
+
cos, sin = cos.to(x.device), sin.to(x.device)
|
| 274 |
+
|
| 275 |
+
if use_real_unbind_dim == -1:
|
| 276 |
+
# Used for flux, cogvideox, hunyuan-dit
|
| 277 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
| 278 |
+
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
| 279 |
+
elif use_real_unbind_dim == -2:
|
| 280 |
+
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
| 281 |
+
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
| 282 |
+
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
| 283 |
+
else:
|
| 284 |
+
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
| 285 |
+
|
| 286 |
+
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
| 287 |
+
|
| 288 |
+
return out
|
| 289 |
+
else:
|
| 290 |
+
# used for lumina
|
| 291 |
+
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
| 292 |
+
freqs_cis = freqs_cis.unsqueeze(2)
|
| 293 |
+
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
| 294 |
+
|
| 295 |
+
return x_out.type_as(x)
|
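# Editorial sketch (not part of the uploaded file): applying the rotary embedding in the
# [batch, seq, heads, head_dim] layout used by the Flux 2 attention processors (sequence_dim=1).
_cos, _sin = get_1d_rotary_pos_embed(128, pos=16, use_real=True, repeat_interleave_real=True)
_q = torch.randn(1, 16, 8, 128)                      # [B, S, heads, head_dim]
_q_rot = apply_rotary_emb(_q, (_cos, _sin), sequence_dim=1)
assert _q_rot.shape == _q.shape                      # rotation preserves the tensor shape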
| 296 |
+
|
| 297 |
+
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
| 298 |
+
query = attn.to_q(hidden_states)
|
| 299 |
+
key = attn.to_k(hidden_states)
|
| 300 |
+
value = attn.to_v(hidden_states)
|
| 301 |
+
|
| 302 |
+
encoder_query = encoder_key = encoder_value = None
|
| 303 |
+
if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
|
| 304 |
+
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
| 305 |
+
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
| 306 |
+
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
| 307 |
+
|
| 308 |
+
return query, key, value, encoder_query, encoder_key, encoder_value
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
| 312 |
+
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
| 313 |
+
|
| 314 |
+
encoder_query = encoder_key = encoder_value = (None,)
|
| 315 |
+
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
| 316 |
+
encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
|
| 317 |
+
|
| 318 |
+
return query, key, value, encoder_query, encoder_key, encoder_value
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
| 322 |
+
return _get_projections(attn, hidden_states, encoder_hidden_states)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class Flux2SwiGLU(nn.Module):
|
| 326 |
+
"""
|
| 327 |
+
Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
|
| 328 |
+
layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
def __init__(self):
|
| 332 |
+
super().__init__()
|
| 333 |
+
self.gate_fn = nn.SiLU()
|
| 334 |
+
|
| 335 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 336 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 337 |
+
x = self.gate_fn(x1) * x2
|
| 338 |
+
return x
|
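# Editorial sketch (not part of the uploaded file): the gated activation halves the last
# dimension, which is why the feedforward and parallel-attention layers below project to
# twice the desired hidden width before applying it.
assert Flux2SwiGLU()(torch.randn(2, 10, 64)).shape == (2, 10, 32)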
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class Flux2FeedForward(nn.Module):
|
| 342 |
+
def __init__(
|
| 343 |
+
self,
|
| 344 |
+
dim: int,
|
| 345 |
+
dim_out: Optional[int] = None,
|
| 346 |
+
mult: float = 3.0,
|
| 347 |
+
inner_dim: Optional[int] = None,
|
| 348 |
+
bias: bool = False,
|
| 349 |
+
):
|
| 350 |
+
super().__init__()
|
| 351 |
+
if inner_dim is None:
|
| 352 |
+
inner_dim = int(dim * mult)
|
| 353 |
+
dim_out = dim_out or dim
|
| 354 |
+
|
| 355 |
+
# Flux2SwiGLU will reduce the dimension by half
|
| 356 |
+
self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
|
| 357 |
+
self.act_fn = Flux2SwiGLU()
|
| 358 |
+
self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
|
| 359 |
+
|
| 360 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 361 |
+
x = self.linear_in(x)
|
| 362 |
+
x = self.act_fn(x)
|
| 363 |
+
x = self.linear_out(x)
|
| 364 |
+
return x
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class Flux2AttnProcessor:
|
| 368 |
+
_attention_backend = None
|
| 369 |
+
_parallel_config = None
|
| 370 |
+
|
| 371 |
+
def __init__(self):
|
| 372 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 373 |
+
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
| 374 |
+
|
| 375 |
+
def __call__(
|
| 376 |
+
self,
|
| 377 |
+
attn: "Flux2Attention",
|
| 378 |
+
hidden_states: torch.Tensor,
|
| 379 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 380 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 381 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 382 |
+
) -> torch.Tensor:
|
| 383 |
+
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
| 384 |
+
attn, hidden_states, encoder_hidden_states
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 388 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 389 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 390 |
+
|
| 391 |
+
query = attn.norm_q(query)
|
| 392 |
+
key = attn.norm_k(key)
|
| 393 |
+
|
| 394 |
+
if attn.added_kv_proj_dim is not None:
|
| 395 |
+
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
| 396 |
+
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
| 397 |
+
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
| 398 |
+
|
| 399 |
+
encoder_query = attn.norm_added_q(encoder_query)
|
| 400 |
+
encoder_key = attn.norm_added_k(encoder_key)
|
| 401 |
+
|
| 402 |
+
query = torch.cat([encoder_query, query], dim=1)
|
| 403 |
+
key = torch.cat([encoder_key, key], dim=1)
|
| 404 |
+
value = torch.cat([encoder_value, value], dim=1)
|
| 405 |
+
|
| 406 |
+
if image_rotary_emb is not None:
|
| 407 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 408 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 409 |
+
|
| 410 |
+
hidden_states = attention_forward(
|
| 411 |
+
query,
|
| 412 |
+
key,
|
| 413 |
+
value,
|
| 414 |
+
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
| 415 |
+
)
|
| 416 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 417 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 418 |
+
|
| 419 |
+
if encoder_hidden_states is not None:
|
| 420 |
+
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
| 421 |
+
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
| 422 |
+
)
|
| 423 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
| 424 |
+
|
| 425 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 426 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 427 |
+
|
| 428 |
+
if encoder_hidden_states is not None:
|
| 429 |
+
return hidden_states, encoder_hidden_states
|
| 430 |
+
else:
|
| 431 |
+
return hidden_states
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class Flux2Attention(torch.nn.Module):
|
| 435 |
+
_default_processor_cls = Flux2AttnProcessor
|
| 436 |
+
_available_processors = [Flux2AttnProcessor]
|
| 437 |
+
|
| 438 |
+
def __init__(
|
| 439 |
+
self,
|
| 440 |
+
query_dim: int,
|
| 441 |
+
heads: int = 8,
|
| 442 |
+
dim_head: int = 64,
|
| 443 |
+
dropout: float = 0.0,
|
| 444 |
+
bias: bool = False,
|
| 445 |
+
added_kv_proj_dim: Optional[int] = None,
|
| 446 |
+
added_proj_bias: Optional[bool] = True,
|
| 447 |
+
out_bias: bool = True,
|
| 448 |
+
eps: float = 1e-5,
|
| 449 |
+
out_dim: int = None,
|
| 450 |
+
elementwise_affine: bool = True,
|
| 451 |
+
processor=None,
|
| 452 |
+
):
|
| 453 |
+
super().__init__()
|
| 454 |
+
|
| 455 |
+
self.head_dim = dim_head
|
| 456 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 457 |
+
self.query_dim = query_dim
|
| 458 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 459 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 460 |
+
|
| 461 |
+
self.use_bias = bias
|
| 462 |
+
self.dropout = dropout
|
| 463 |
+
|
| 464 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
| 465 |
+
self.added_proj_bias = added_proj_bias
|
| 466 |
+
|
| 467 |
+
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 468 |
+
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 469 |
+
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
|
| 470 |
+
|
| 471 |
+
# QK Norm
|
| 472 |
+
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 473 |
+
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 474 |
+
|
| 475 |
+
self.to_out = torch.nn.ModuleList([])
|
| 476 |
+
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
| 477 |
+
self.to_out.append(torch.nn.Dropout(dropout))
|
| 478 |
+
|
| 479 |
+
if added_kv_proj_dim is not None:
|
| 480 |
+
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
|
| 481 |
+
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
|
| 482 |
+
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 483 |
+
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 484 |
+
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
| 485 |
+
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
|
| 486 |
+
|
| 487 |
+
if processor is None:
|
| 488 |
+
processor = self._default_processor_cls()
|
| 489 |
+
self.processor = processor
|
| 490 |
+
|
| 491 |
+
def forward(
|
| 492 |
+
self,
|
| 493 |
+
hidden_states: torch.Tensor,
|
| 494 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 495 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 496 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 497 |
+
**kwargs,
|
| 498 |
+
) -> torch.Tensor:
|
| 499 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
| 500 |
+
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
| 501 |
+
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
class Flux2ParallelSelfAttnProcessor:
|
| 505 |
+
_attention_backend = None
|
| 506 |
+
_parallel_config = None
|
| 507 |
+
|
| 508 |
+
def __init__(self):
|
| 509 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 510 |
+
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
| 511 |
+
|
| 512 |
+
def __call__(
|
| 513 |
+
self,
|
| 514 |
+
attn: "Flux2ParallelSelfAttention",
|
| 515 |
+
hidden_states: torch.Tensor,
|
| 516 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 517 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 518 |
+
) -> torch.Tensor:
|
| 519 |
+
# Parallel in (QKV + MLP in) projection
|
| 520 |
+
hidden_states = attn.to_qkv_mlp_proj(hidden_states)
|
| 521 |
+
qkv, mlp_hidden_states = torch.split(
|
| 522 |
+
hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
# Handle the attention logic
|
| 526 |
+
query, key, value = qkv.chunk(3, dim=-1)
|
| 527 |
+
|
| 528 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 529 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 530 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 531 |
+
|
| 532 |
+
query = attn.norm_q(query)
|
| 533 |
+
key = attn.norm_k(key)
|
| 534 |
+
|
| 535 |
+
if image_rotary_emb is not None:
|
| 536 |
+
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
| 537 |
+
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
| 538 |
+
|
| 539 |
+
hidden_states = attention_forward(
|
| 540 |
+
query,
|
| 541 |
+
key,
|
| 542 |
+
value,
|
| 543 |
+
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
| 544 |
+
)
|
| 545 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 546 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 547 |
+
|
| 548 |
+
# Handle the feedforward (FF) logic
|
| 549 |
+
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
| 550 |
+
|
| 551 |
+
# Concatenate and parallel output projection
|
| 552 |
+
hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
|
| 553 |
+
hidden_states = attn.to_out(hidden_states)
|
| 554 |
+
|
| 555 |
+
return hidden_states
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
class Flux2ParallelSelfAttention(torch.nn.Module):
|
| 559 |
+
"""
|
| 560 |
+
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
| 561 |
+
|
| 562 |
+
This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
|
| 563 |
+
input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
|
| 564 |
+
paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
|
| 565 |
+
"""
|
| 566 |
+
|
| 567 |
+
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
| 568 |
+
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
| 569 |
+
# Does not support QKV fusion as the QKV projections are always fused
|
| 570 |
+
_supports_qkv_fusion = False
|
| 571 |
+
|
| 572 |
+
def __init__(
|
| 573 |
+
self,
|
| 574 |
+
query_dim: int,
|
| 575 |
+
heads: int = 8,
|
| 576 |
+
dim_head: int = 64,
|
| 577 |
+
dropout: float = 0.0,
|
| 578 |
+
bias: bool = False,
|
| 579 |
+
out_bias: bool = True,
|
| 580 |
+
eps: float = 1e-5,
|
| 581 |
+
out_dim: int = None,
|
| 582 |
+
elementwise_affine: bool = True,
|
| 583 |
+
mlp_ratio: float = 4.0,
|
| 584 |
+
mlp_mult_factor: int = 2,
|
| 585 |
+
processor=None,
|
| 586 |
+
):
|
| 587 |
+
super().__init__()
|
| 588 |
+
|
| 589 |
+
self.head_dim = dim_head
|
| 590 |
+
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
| 591 |
+
self.query_dim = query_dim
|
| 592 |
+
self.out_dim = out_dim if out_dim is not None else query_dim
|
| 593 |
+
self.heads = out_dim // dim_head if out_dim is not None else heads
|
| 594 |
+
|
| 595 |
+
self.use_bias = bias
|
| 596 |
+
self.dropout = dropout
|
| 597 |
+
|
| 598 |
+
self.mlp_ratio = mlp_ratio
|
| 599 |
+
self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
|
| 600 |
+
self.mlp_mult_factor = mlp_mult_factor
|
| 601 |
+
|
| 602 |
+
# Fused QKV projections + MLP input projection
|
| 603 |
+
self.to_qkv_mlp_proj = torch.nn.Linear(
|
| 604 |
+
self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
|
| 605 |
+
)
|
| 606 |
+
self.mlp_act_fn = Flux2SwiGLU()
|
| 607 |
+
|
| 608 |
+
# QK Norm
|
| 609 |
+
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 610 |
+
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
| 611 |
+
|
| 612 |
+
# Fused attention output projection + MLP output projection
|
| 613 |
+
self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
|
| 614 |
+
|
| 615 |
+
if processor is None:
|
| 616 |
+
processor = self._default_processor_cls()
|
| 617 |
+
self.processor = processor
|
| 618 |
+
|
| 619 |
+
def forward(
|
| 620 |
+
self,
|
| 621 |
+
hidden_states: torch.Tensor,
|
| 622 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 623 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
| 624 |
+
**kwargs,
|
| 625 |
+
) -> torch.Tensor:
|
| 626 |
+
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
| 627 |
+
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
| 628 |
+
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
|
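# Editorial sketch (not part of the uploaded file), using deliberately small illustrative
# sizes (the real single-stream blocks use dim 6144 with 48 heads of 128): the fused input
# projection emits 3 * inner_dim features for QKV plus mlp_hidden_dim * mlp_mult_factor
# features for the SwiGLU branch, which the processor then splits back apart.
_parallel_attn = Flux2ParallelSelfAttention(query_dim=256, heads=4, dim_head=64, mlp_ratio=3.0, mlp_mult_factor=2)
assert _parallel_attn.to_qkv_mlp_proj.out_features == 3 * 256 + int(256 * 3.0) * 2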
| 629 |
+
|
| 630 |
+
|
| 631 |
+
class Flux2SingleTransformerBlock(nn.Module):
|
| 632 |
+
def __init__(
|
| 633 |
+
self,
|
| 634 |
+
dim: int,
|
| 635 |
+
num_attention_heads: int,
|
| 636 |
+
attention_head_dim: int,
|
| 637 |
+
mlp_ratio: float = 3.0,
|
| 638 |
+
eps: float = 1e-6,
|
| 639 |
+
bias: bool = False,
|
| 640 |
+
):
|
| 641 |
+
super().__init__()
|
| 642 |
+
|
| 643 |
+
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 644 |
+
|
| 645 |
+
# Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
|
| 646 |
+
# is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
|
| 647 |
+
# for a visual depiction of this type of transformer block.
|
| 648 |
+
self.attn = Flux2ParallelSelfAttention(
|
| 649 |
+
query_dim=dim,
|
| 650 |
+
dim_head=attention_head_dim,
|
| 651 |
+
heads=num_attention_heads,
|
| 652 |
+
out_dim=dim,
|
| 653 |
+
bias=bias,
|
| 654 |
+
out_bias=bias,
|
| 655 |
+
eps=eps,
|
| 656 |
+
mlp_ratio=mlp_ratio,
|
| 657 |
+
mlp_mult_factor=2,
|
| 658 |
+
processor=Flux2ParallelSelfAttnProcessor(),
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
def forward(
|
| 662 |
+
self,
|
| 663 |
+
hidden_states: torch.Tensor,
|
| 664 |
+
encoder_hidden_states: Optional[torch.Tensor],
|
| 665 |
+
temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
| 666 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 667 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 668 |
+
split_hidden_states: bool = False,
|
| 669 |
+
text_seq_len: Optional[int] = None,
|
| 670 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 671 |
+
# If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
|
| 672 |
+
# concatenated
|
| 673 |
+
if encoder_hidden_states is not None:
|
| 674 |
+
text_seq_len = encoder_hidden_states.shape[1]
|
| 675 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 676 |
+
|
| 677 |
+
mod_shift, mod_scale, mod_gate = temb_mod_params
|
| 678 |
+
|
| 679 |
+
norm_hidden_states = self.norm(hidden_states)
|
| 680 |
+
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
|
| 681 |
+
|
| 682 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 683 |
+
attn_output = self.attn(
|
| 684 |
+
hidden_states=norm_hidden_states,
|
| 685 |
+
image_rotary_emb=image_rotary_emb,
|
| 686 |
+
**joint_attention_kwargs,
|
| 687 |
+
)
|
| 688 |
+
|
| 689 |
+
hidden_states = hidden_states + mod_gate * attn_output
|
| 690 |
+
if hidden_states.dtype == torch.float16:
|
| 691 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 692 |
+
|
| 693 |
+
if split_hidden_states:
|
| 694 |
+
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
| 695 |
+
return encoder_hidden_states, hidden_states
|
| 696 |
+
else:
|
| 697 |
+
return hidden_states
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class Flux2TransformerBlock(nn.Module):
|
| 701 |
+
def __init__(
|
| 702 |
+
self,
|
| 703 |
+
dim: int,
|
| 704 |
+
num_attention_heads: int,
|
| 705 |
+
attention_head_dim: int,
|
| 706 |
+
mlp_ratio: float = 3.0,
|
| 707 |
+
eps: float = 1e-6,
|
| 708 |
+
bias: bool = False,
|
| 709 |
+
):
|
| 710 |
+
super().__init__()
|
| 711 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
| 712 |
+
|
| 713 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 714 |
+
self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 715 |
+
|
| 716 |
+
self.attn = Flux2Attention(
|
| 717 |
+
query_dim=dim,
|
| 718 |
+
added_kv_proj_dim=dim,
|
| 719 |
+
dim_head=attention_head_dim,
|
| 720 |
+
heads=num_attention_heads,
|
| 721 |
+
out_dim=dim,
|
| 722 |
+
bias=bias,
|
| 723 |
+
added_proj_bias=bias,
|
| 724 |
+
out_bias=bias,
|
| 725 |
+
eps=eps,
|
| 726 |
+
processor=Flux2AttnProcessor(),
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 730 |
+
self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
| 731 |
+
|
| 732 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
|
| 733 |
+
self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
|
| 734 |
+
|
| 735 |
+
def forward(
|
| 736 |
+
self,
|
| 737 |
+
hidden_states: torch.Tensor,
|
| 738 |
+
encoder_hidden_states: torch.Tensor,
|
| 739 |
+
temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
| 740 |
+
temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
| 741 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 742 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 743 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 744 |
+
joint_attention_kwargs = joint_attention_kwargs or {}
|
| 745 |
+
|
| 746 |
+
# Modulation parameters shape: [1, 1, self.dim]
|
| 747 |
+
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
|
| 748 |
+
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
|
| 749 |
+
|
| 750 |
+
# Img stream
|
| 751 |
+
norm_hidden_states = self.norm1(hidden_states)
|
| 752 |
+
norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
|
| 753 |
+
|
| 754 |
+
# Conditioning txt stream
|
| 755 |
+
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
|
| 756 |
+
norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
|
| 757 |
+
|
| 758 |
+
# Attention on concatenated img + txt stream
|
| 759 |
+
attention_outputs = self.attn(
|
| 760 |
+
hidden_states=norm_hidden_states,
|
| 761 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
| 762 |
+
image_rotary_emb=image_rotary_emb,
|
| 763 |
+
**joint_attention_kwargs,
|
| 764 |
+
)
|
| 765 |
+
|
| 766 |
+
attn_output, context_attn_output = attention_outputs
|
| 767 |
+
|
| 768 |
+
# Process attention outputs for the image stream (`hidden_states`).
|
| 769 |
+
attn_output = gate_msa * attn_output
|
| 770 |
+
hidden_states = hidden_states + attn_output
|
| 771 |
+
|
| 772 |
+
norm_hidden_states = self.norm2(hidden_states)
|
| 773 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
| 774 |
+
|
| 775 |
+
ff_output = self.ff(norm_hidden_states)
|
| 776 |
+
hidden_states = hidden_states + gate_mlp * ff_output
|
| 777 |
+
|
| 778 |
+
# Process attention outputs for the text stream (`encoder_hidden_states`).
|
| 779 |
+
context_attn_output = c_gate_msa * context_attn_output
|
| 780 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
| 781 |
+
|
| 782 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
| 783 |
+
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
|
| 784 |
+
|
| 785 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
| 786 |
+
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
|
| 787 |
+
if encoder_hidden_states.dtype == torch.float16:
|
| 788 |
+
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
|
| 789 |
+
|
| 790 |
+
return encoder_hidden_states, hidden_states
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
class Flux2PosEmbed(nn.Module):
|
| 794 |
+
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
| 795 |
+
def __init__(self, theta: int, axes_dim: List[int]):
|
| 796 |
+
super().__init__()
|
| 797 |
+
self.theta = theta
|
| 798 |
+
self.axes_dim = axes_dim
|
| 799 |
+
|
| 800 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
| 801 |
+
# Expected ids shape: [S, len(self.axes_dim)]
|
| 802 |
+
cos_out = []
|
| 803 |
+
sin_out = []
|
| 804 |
+
pos = ids.float()
|
| 805 |
+
is_mps = ids.device.type == "mps"
|
| 806 |
+
is_npu = ids.device.type == "npu"
|
| 807 |
+
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
| 808 |
+
# Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
|
| 809 |
+
for i in range(len(self.axes_dim)):
|
| 810 |
+
cos, sin = get_1d_rotary_pos_embed(
|
| 811 |
+
self.axes_dim[i],
|
| 812 |
+
pos[..., i],
|
| 813 |
+
theta=self.theta,
|
| 814 |
+
repeat_interleave_real=True,
|
| 815 |
+
use_real=True,
|
| 816 |
+
freqs_dtype=freqs_dtype,
|
| 817 |
+
)
|
| 818 |
+
cos_out.append(cos)
|
| 819 |
+
sin_out.append(sin)
|
| 820 |
+
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
| 821 |
+
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
| 822 |
+
return freqs_cos, freqs_sin
|
| 823 |
+
|
| 824 |
+
|
| 825 |
+
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
| 826 |
+
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
| 827 |
+
super().__init__()
|
| 828 |
+
|
| 829 |
+
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
| 830 |
+
self.timestep_embedder = TimestepEmbedding(
|
| 831 |
+
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
self.guidance_embedder = TimestepEmbedding(
|
| 835 |
+
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
| 836 |
+
)
|
| 837 |
+
|
| 838 |
+
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
| 839 |
+
timesteps_proj = self.time_proj(timestep)
|
| 840 |
+
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
| 841 |
+
|
| 842 |
+
guidance_proj = self.time_proj(guidance)
|
| 843 |
+
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
| 844 |
+
|
| 845 |
+
time_guidance_emb = timesteps_emb + guidance_emb
|
| 846 |
+
|
| 847 |
+
return time_guidance_emb
|
| 848 |
+
|
| 849 |
+
|
| 850 |
+
class Flux2Modulation(nn.Module):
|
| 851 |
+
def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
|
| 852 |
+
super().__init__()
|
| 853 |
+
self.mod_param_sets = mod_param_sets
|
| 854 |
+
|
| 855 |
+
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
|
| 856 |
+
self.act_fn = nn.SiLU()
|
| 857 |
+
|
| 858 |
+
def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
| 859 |
+
mod = self.act_fn(temb)
|
| 860 |
+
mod = self.linear(mod)
|
| 861 |
+
|
| 862 |
+
if mod.ndim == 2:
|
| 863 |
+
mod = mod.unsqueeze(1)
|
| 864 |
+
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
|
| 865 |
+
# Return tuple of 3-tuples of modulation params shift/scale/gate
|
| 866 |
+
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
|
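# Editorial sketch (not part of the uploaded file): the modulation head maps the pooled
# timestep/guidance embedding to `mod_param_sets` triples of (shift, scale, gate), each
# broadcastable over the token dimension.
_mod = Flux2Modulation(dim=64, mod_param_sets=2)
_sets = _mod(torch.randn(2, 64))
assert len(_sets) == 2 and len(_sets[0]) == 3 and _sets[0][0].shape == (2, 1, 64)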
| 867 |
+
|
| 868 |
+
|
| 869 |
+
class Flux2DiT(torch.nn.Module):
|
| 870 |
+
def __init__(
|
| 871 |
+
self,
|
| 872 |
+
patch_size: int = 1,
|
| 873 |
+
in_channels: int = 128,
|
| 874 |
+
out_channels: Optional[int] = None,
|
| 875 |
+
num_layers: int = 8,
|
| 876 |
+
num_single_layers: int = 48,
|
| 877 |
+
attention_head_dim: int = 128,
|
| 878 |
+
num_attention_heads: int = 48,
|
| 879 |
+
joint_attention_dim: int = 15360,
|
| 880 |
+
timestep_guidance_channels: int = 256,
|
| 881 |
+
mlp_ratio: float = 3.0,
|
| 882 |
+
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
| 883 |
+
rope_theta: int = 2000,
|
| 884 |
+
eps: float = 1e-6,
|
| 885 |
+
):
|
| 886 |
+
super().__init__()
|
| 887 |
+
self.out_channels = out_channels or in_channels
|
| 888 |
+
self.inner_dim = num_attention_heads * attention_head_dim
|
| 889 |
+
|
| 890 |
+
# 1. Sinusoidal positional embedding for RoPE on image and text tokens
|
| 891 |
+
self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
|
| 892 |
+
|
| 893 |
+
# 2. Combined timestep + guidance embedding
|
| 894 |
+
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
| 895 |
+
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
| 899 |
+
# Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
|
| 900 |
+
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
| 901 |
+
self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
|
| 902 |
+
# Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
|
| 903 |
+
self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
|
| 904 |
+
|
| 905 |
+
# 4. Input projections
|
| 906 |
+
self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
|
| 907 |
+
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
|
| 908 |
+
|
| 909 |
+
# 5. Double Stream Transformer Blocks
|
| 910 |
+
self.transformer_blocks = nn.ModuleList(
|
| 911 |
+
[
|
| 912 |
+
Flux2TransformerBlock(
|
| 913 |
+
dim=self.inner_dim,
|
| 914 |
+
num_attention_heads=num_attention_heads,
|
| 915 |
+
attention_head_dim=attention_head_dim,
|
| 916 |
+
mlp_ratio=mlp_ratio,
|
| 917 |
+
eps=eps,
|
| 918 |
+
bias=False,
|
| 919 |
+
)
|
| 920 |
+
for _ in range(num_layers)
|
| 921 |
+
]
|
| 922 |
+
)
|
| 923 |
+
|
| 924 |
+
# 6. Single Stream Transformer Blocks
|
| 925 |
+
self.single_transformer_blocks = nn.ModuleList(
|
| 926 |
+
[
|
| 927 |
+
Flux2SingleTransformerBlock(
|
| 928 |
+
dim=self.inner_dim,
|
| 929 |
+
num_attention_heads=num_attention_heads,
|
| 930 |
+
attention_head_dim=attention_head_dim,
|
| 931 |
+
mlp_ratio=mlp_ratio,
|
| 932 |
+
eps=eps,
|
| 933 |
+
bias=False,
|
| 934 |
+
)
|
| 935 |
+
for _ in range(num_single_layers)
|
| 936 |
+
]
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
# 7. Output layers
|
| 940 |
+
self.norm_out = AdaLayerNormContinuous(
|
| 941 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
|
| 942 |
+
)
|
| 943 |
+
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
| 944 |
+
|
| 945 |
+
self.gradient_checkpointing = False
|
| 946 |
+
|
| 947 |
+
def forward(
|
| 948 |
+
self,
|
| 949 |
+
hidden_states: torch.Tensor,
|
| 950 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 951 |
+
timestep: torch.LongTensor = None,
|
| 952 |
+
img_ids: torch.Tensor = None,
|
| 953 |
+
txt_ids: torch.Tensor = None,
|
| 954 |
+
guidance: torch.Tensor = None,
|
| 955 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 956 |
+
return_dict: bool = True,
|
| 957 |
+
use_gradient_checkpointing=False,
|
| 958 |
+
use_gradient_checkpointing_offload=False,
|
| 959 |
+
) -> torch.Tensor:
|
| 960 |
+
"""
|
| 961 |
+
The [`Flux2DiT`] forward method.
|
| 962 |
+
|
| 963 |
+
Args:
|
| 964 |
+
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
| 965 |
+
Input `hidden_states`.
|
| 966 |
+
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
| 967 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
| 968 |
+
timestep ( `torch.LongTensor`):
|
| 969 |
+
Used to indicate denoising step.
|
| 970 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
| 971 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
| 972 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 973 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 974 |
+
`self.processor` in
|
| 975 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 976 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 977 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
| 978 |
+
tuple.
|
| 979 |
+
|
| 980 |
+
Returns:
|
| 981 |
+
`torch.Tensor`: the projected output sequence for the image tokens (this implementation ignores `return_dict`
|
| 982 |
+
and always returns a plain tensor).
|
| 983 |
+
"""
|
| 984 |
+
# 0. Handle input arguments
|
| 985 |
+
if joint_attention_kwargs is not None:
|
| 986 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
| 987 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
| 988 |
+
else:
|
| 989 |
+
lora_scale = 1.0
|
| 990 |
+
|
| 991 |
+
num_txt_tokens = encoder_hidden_states.shape[1]
|
| 992 |
+
|
| 993 |
+
# 1. Calculate timestep embedding and modulation parameters
|
| 994 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
| 995 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
| 996 |
+
|
| 997 |
+
temb = self.time_guidance_embed(timestep, guidance)
|
| 998 |
+
|
| 999 |
+
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
| 1000 |
+
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
| 1001 |
+
single_stream_mod = self.single_stream_modulation(temb)[0]
|
| 1002 |
+
|
| 1003 |
+
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
| 1004 |
+
hidden_states = self.x_embedder(hidden_states)
|
| 1005 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
| 1006 |
+
|
| 1007 |
+
# 3. Calculate RoPE embeddings from image and text tokens
|
| 1008 |
+
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
| 1009 |
+
# text prompts of different lengths. Is this a use case we want to support?
|
| 1010 |
+
if img_ids.ndim == 3:
|
| 1011 |
+
img_ids = img_ids[0]
|
| 1012 |
+
if txt_ids.ndim == 3:
|
| 1013 |
+
txt_ids = txt_ids[0]
|
| 1014 |
+
|
| 1015 |
+
image_rotary_emb = self.pos_embed(img_ids)
|
| 1016 |
+
text_rotary_emb = self.pos_embed(txt_ids)
|
| 1017 |
+
concat_rotary_emb = (
|
| 1018 |
+
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
| 1019 |
+
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
| 1020 |
+
)
|
| 1021 |
+
|
| 1022 |
+
# 4. Double Stream Transformer Blocks
|
| 1023 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
| 1024 |
+
encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
|
| 1025 |
+
block,
|
| 1026 |
+
use_gradient_checkpointing=use_gradient_checkpointing,
|
| 1027 |
+
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
| 1028 |
+
hidden_states=hidden_states,
|
| 1029 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1030 |
+
temb_mod_params_img=double_stream_mod_img,
|
| 1031 |
+
temb_mod_params_txt=double_stream_mod_txt,
|
| 1032 |
+
image_rotary_emb=concat_rotary_emb,
|
| 1033 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1034 |
+
)
|
| 1035 |
+
# Concatenate text and image streams for single-block inference
|
| 1036 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
| 1037 |
+
|
| 1038 |
+
# 5. Single Stream Transformer Blocks
|
| 1039 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
| 1040 |
+
hidden_states = gradient_checkpoint_forward(
|
| 1041 |
+
block,
|
| 1042 |
+
use_gradient_checkpointing=use_gradient_checkpointing,
|
| 1043 |
+
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
| 1044 |
+
hidden_states=hidden_states,
|
| 1045 |
+
encoder_hidden_states=None,
|
| 1046 |
+
temb_mod_params=single_stream_mod,
|
| 1047 |
+
image_rotary_emb=concat_rotary_emb,
|
| 1048 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
| 1049 |
+
)
|
| 1050 |
+
# Remove text tokens from concatenated stream
|
| 1051 |
+
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
| 1052 |
+
|
| 1053 |
+
# 6. Output layers
|
| 1054 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
| 1055 |
+
output = self.proj_out(hidden_states)
|
| 1056 |
+
|
| 1057 |
+
return output
|
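Editorial note: a shape-check sketch for the transformer above (not part of the uploaded file). The hyperparameters are deliberately tiny and purely illustrative (the defaults in `Flux2DiT.__init__` describe the real model), and it assumes the project's `attention_forward` backend can run on the local device.

dit = Flux2DiT(
    in_channels=16,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=32,
    num_attention_heads=4,
    joint_attention_dim=64,
    axes_dims_rope=(8, 8, 8, 8),
)
batch, n_img, n_txt = 1, 64, 7
latents = torch.randn(batch, n_img, 16)     # packed image tokens
context = torch.randn(batch, n_txt, 64)     # text conditioning tokens
img_ids = torch.zeros(n_img, 4)             # one position index per RoPE axis
txt_ids = torch.zeros(n_txt, 4)
out = dit(
    hidden_states=latents,
    encoder_hidden_states=context,
    timestep=torch.tensor([1.0]),
    guidance=torch.tensor([4.0]),
    img_ids=img_ids,
    txt_ids=txt_ids,
)
assert out.shape == (batch, n_img, 16)      # out_channels defaults to in_channels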
diffsynth/models/flux2_text_encoder.py
ADDED
|
@@ -0,0 +1,58 @@
| 1 |
+
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
|
| 5 |
+
def __init__(self):
|
| 6 |
+
config = Mistral3Config(**{
|
| 7 |
+
"architectures": [
|
| 8 |
+
"Mistral3ForConditionalGeneration"
|
| 9 |
+
],
|
| 10 |
+
"dtype": "bfloat16",
|
| 11 |
+
"image_token_index": 10,
|
| 12 |
+
"model_type": "mistral3",
|
| 13 |
+
"multimodal_projector_bias": False,
|
| 14 |
+
"projector_hidden_act": "gelu",
|
| 15 |
+
"spatial_merge_size": 2,
|
| 16 |
+
"text_config": {
|
| 17 |
+
"attention_dropout": 0.0,
|
| 18 |
+
"dtype": "bfloat16",
|
| 19 |
+
"head_dim": 128,
|
| 20 |
+
"hidden_act": "silu",
|
| 21 |
+
"hidden_size": 5120,
|
| 22 |
+
"initializer_range": 0.02,
|
| 23 |
+
"intermediate_size": 32768,
|
| 24 |
+
"max_position_embeddings": 131072,
|
| 25 |
+
"model_type": "mistral",
|
| 26 |
+
"num_attention_heads": 32,
|
| 27 |
+
"num_hidden_layers": 40,
|
| 28 |
+
"num_key_value_heads": 8,
|
| 29 |
+
"rms_norm_eps": 1e-05,
|
| 30 |
+
"rope_theta": 1000000000.0,
|
| 31 |
+
"sliding_window": None,
|
| 32 |
+
"use_cache": True,
|
| 33 |
+
"vocab_size": 131072
|
| 34 |
+
},
|
| 35 |
+
"transformers_version": "4.57.1",
|
| 36 |
+
"vision_config": {
|
| 37 |
+
"attention_dropout": 0.0,
|
| 38 |
+
"dtype": "bfloat16",
|
| 39 |
+
"head_dim": 64,
|
| 40 |
+
"hidden_act": "silu",
|
| 41 |
+
"hidden_size": 1024,
|
| 42 |
+
"image_size": 1540,
|
| 43 |
+
"initializer_range": 0.02,
|
| 44 |
+
"intermediate_size": 4096,
|
| 45 |
+
"model_type": "pixtral",
|
| 46 |
+
"num_attention_heads": 16,
|
| 47 |
+
"num_channels": 3,
|
| 48 |
+
"num_hidden_layers": 24,
|
| 49 |
+
"patch_size": 14,
|
| 50 |
+
"rope_theta": 10000.0
|
| 51 |
+
},
|
| 52 |
+
"vision_feature_layer": -1
|
| 53 |
+
})
|
| 54 |
+
super().__init__(config)
|
| 55 |
+
|
| 56 |
+
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
|
| 57 |
+
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
|
| 58 |
+
|
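For orientation, here is a minimal, hedged sketch of what an embedding pass through this wrapper could look like. The checkpoint path, the tokenizer choice, and the use of `output_hidden_states` are illustrative assumptions, not the pipeline's actual loading logic (which lives in the DiffSynth loaders), and a randomly initialized instance of this 40-layer decoder is far too large to build casually.

import torch
from transformers import AutoTokenizer

# Hypothetical checkpoint location; the real pipeline loads weights through DiffSynth's own loaders.
tokenizer = AutoTokenizer.from_pretrained("path/to/flux2-text-encoder")
text_encoder = Flux2TextEncoder().to(dtype=torch.bfloat16, device="cuda").eval()

with torch.no_grad():
    batch = tokenizer(["a photo of a cat"], return_tensors="pt").to("cuda")
    out = text_encoder(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        output_hidden_states=True,   # expose per-layer hidden states for use as prompt embeddings
    )

# Last decoder hidden state: (batch, sequence, 5120)
prompt_emb = out.hidden_states[-1]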
diffsynth/models/flux2_vae.py
ADDED
The diff for this file is too large to render.
See raw diff
diffsynth/models/flux_controlnet.py
ADDED
@@ -0,0 +1,384 @@
import hashlib
import torch
from einops import rearrange, repeat
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
# from .utils import hash_state_dict_keys, init_weights_on_device
from contextlib import contextmanager

# NOTE: `hash_state_dict_keys` relies on `convert_state_dict_keys_to_single_str`, a helper that
# flattens the state-dict keys (and optionally tensor shapes) into one string; it is expected to
# be available from the surrounding package.
def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()

@contextmanager
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):

    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)

class FluxControlNet(torch.nn.Module):
    def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(64, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])

        self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
        self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])

        self.mode_dict = mode_dict
        self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
        self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)


    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids


    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states


    def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
        if len(res_stack) == 0:
            return [torch.zeros_like(hidden_states)] * num_blocks
        interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
        aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
        return aligned_res_stack


    def forward(
        self,
        hidden_states,
        controlnet_conditioning,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        processor_id=None,
        tiled=False, tile_size=128, tile_stride=64,
        **kwargs
    ):
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)

        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
        if self.guidance_embedder is not None:
            guidance = guidance * 1000
            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
        prompt_emb = self.context_embedder(prompt_emb)
        if self.controlnet_mode_embedder is not None: # Different from FluxDiT
            processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
            processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
            prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
            text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))

        hidden_states = self.patchify(hidden_states)
        hidden_states = self.x_embedder(hidden_states)
        controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT

        controlnet_res_stack = []
        for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_res_stack.append(controlnet_block(hidden_states))

        controlnet_single_res_stack = []
        hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
        for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
            hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
            controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))

        controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
        controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])

        return controlnet_res_stack, controlnet_single_res_stack


    # @staticmethod
    # def state_dict_converter():
    #     return FluxControlNetStateDictConverter()

    def quantize(self):
        def cast_to(weight, dtype=None, device=None, copy=False):
            if device is None or weight.device == device:
                if not copy:
                    if dtype is None or weight.dtype == dtype:
                        return weight
                return weight.to(dtype=dtype, copy=copy)

            r = torch.empty_like(weight, dtype=dtype, device=device)
            r.copy_(weight)
            return r

        def cast_weight(s, input=None, dtype=None, device=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if device is None:
                    device = input.device
            weight = cast_to(s.weight, dtype, device)
            return weight

        def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
            if input is not None:
                if dtype is None:
                    dtype = input.dtype
                if bias_dtype is None:
                    bias_dtype = dtype
                if device is None:
                    device = input.device
            bias = None
            weight = cast_to(s.weight, dtype, device)
            bias = cast_to(s.bias, bias_dtype, device)
            return weight, bias

        class quantized_layer:
            class QLinear(torch.nn.Linear):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight, bias = cast_bias_weight(self, input)
                    return torch.nn.functional.linear(input, weight, bias)

            class QRMSNorm(torch.nn.Module):
                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, hidden_states, **kwargs):
                    weight = cast_weight(self.module, hidden_states)
                    input_dtype = hidden_states.dtype
                    variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
                    hidden_states = hidden_states.to(input_dtype) * weight
                    return hidden_states

            class QEmbedding(torch.nn.Embedding):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def forward(self, input, **kwargs):
                    weight = cast_weight(self, input)
                    return torch.nn.functional.embedding(
                        input, weight, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)

        def replace_layer(model):
            for name, module in model.named_children():
                if isinstance(module, quantized_layer.QRMSNorm):
                    continue
                if isinstance(module, torch.nn.Linear):
                    with init_weights_on_device():
                        new_layer = quantized_layer.QLinear(module.in_features, module.out_features)
                    new_layer.weight = module.weight
                    if module.bias is not None:
                        new_layer.bias = module.bias
                    setattr(model, name, new_layer)
                elif isinstance(module, RMSNorm):
                    if hasattr(module, "quantized"):
                        continue
                    module.quantized = True
                    new_layer = quantized_layer.QRMSNorm(module)
                    setattr(model, name, new_layer)
                elif isinstance(module, torch.nn.Embedding):
                    rows, cols = module.weight.shape
                    new_layer = quantized_layer.QEmbedding(
                        num_embeddings=rows,
                        embedding_dim=cols,
                        _weight=module.weight,
                        # _freeze=module.freeze,
                        padding_idx=module.padding_idx,
                        max_norm=module.max_norm,
                        norm_type=module.norm_type,
                        scale_grad_by_freq=module.scale_grad_by_freq,
                        sparse=module.sparse)
                    setattr(model, name, new_layer)
                else:
                    replace_layer(module)

        replace_layer(self)



class FluxControlNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        hash_value = hash_state_dict_keys(state_dict)
        global_rename_dict = {
            "context_embedder": "context_embedder",
            "x_embedder": "x_embedder",
            "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
            "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
            "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
            "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
            "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
            "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
            "norm_out.linear": "final_norm_out.linear",
            "proj_out": "final_proj_out",
        }
        rename_dict = {
            "proj_out": "proj_out",
            "norm1.linear": "norm1_a.linear",
            "norm1_context.linear": "norm1_b.linear",
            "attn.to_q": "attn.a_to_q",
            "attn.to_k": "attn.a_to_k",
            "attn.to_v": "attn.a_to_v",
            "attn.to_out.0": "attn.a_to_out",
            "attn.add_q_proj": "attn.b_to_q",
            "attn.add_k_proj": "attn.b_to_k",
            "attn.add_v_proj": "attn.b_to_v",
            "attn.to_add_out": "attn.b_to_out",
            "ff.net.0.proj": "ff_a.0",
            "ff.net.2": "ff_a.2",
            "ff_context.net.0.proj": "ff_b.0",
            "ff_context.net.2": "ff_b.2",
            "attn.norm_q": "attn.norm_q_a",
            "attn.norm_k": "attn.norm_k_a",
            "attn.norm_added_q": "attn.norm_q_b",
            "attn.norm_added_k": "attn.norm_k_b",
        }
        rename_dict_single = {
            "attn.to_q": "a_to_q",
            "attn.to_k": "a_to_k",
            "attn.to_v": "a_to_v",
            "attn.norm_q": "norm_q_a",
            "attn.norm_k": "norm_k_a",
            "norm.linear": "norm.linear",
            "proj_mlp": "proj_in_besides_attn",
            "proj_out": "proj_out",
        }
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.endswith(".weight") or name.endswith(".bias"):
                suffix = ".weight" if name.endswith(".weight") else ".bias"
                prefix = name[:-len(suffix)]
                if prefix in global_rename_dict:
                    state_dict_[global_rename_dict[prefix] + suffix] = param
                elif prefix.startswith("transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict:
                        name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                elif prefix.startswith("single_transformer_blocks."):
                    names = prefix.split(".")
                    names[0] = "single_blocks"
                    middle = ".".join(names[2:])
                    if middle in rename_dict_single:
                        name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
                        state_dict_[name_] = param
                    else:
                        state_dict_[name] = param
            else:
                state_dict_[name] = param
        for name in list(state_dict_.keys()):
            if ".proj_in_besides_attn." in name:
                name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
                param = torch.concat([
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
                    state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
                    state_dict_[name],
                ], dim=0)
                state_dict_[name_] = param
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
                state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
                state_dict_.pop(name)
        for name in list(state_dict_.keys()):
            for component in ["a", "b"]:
                if f".{component}_to_q." in name:
                    name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
                    param = torch.concat([
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
                        state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
                    ], dim=0)
                    state_dict_[name_] = param
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
                    state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
        if hash_value == "78d18b9101345ff695f312e7e62538c0":
            extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
        elif hash_value == "b001c89139b5f053c715fe772362dd2a":
            extra_kwargs = {"num_single_blocks": 0}
        elif hash_value == "52357cb26250681367488a8954c271e8":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
        elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
        elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
            extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
        elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
            extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
        else:
            extra_kwargs = {}
        return state_dict_, extra_kwargs


    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
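The residual-alignment rule in `align_res_stack_to_original_blocks` is easiest to see with a toy example. The sketch below only reproduces the index arithmetic used above, with plain integers standing in for the per-block residual tensors.

# With 5 ControlNet joint blocks feeding the DiT's 19 joint blocks:
# interval = ceil(19 / 5) = 4, so residual i is reused for blocks 4*i .. 4*i+3.
num_blocks, res_stack = 19, list(range(5))
interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
aligned = [res_stack[block_id // interval] for block_id in range(num_blocks)]
print(interval)  # 4
print(aligned)   # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]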
diffsynth/models/flux_dit.py
ADDED
@@ -0,0 +1,395 @@
import torch
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange


def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
    batch_size, num_tokens = hidden_states.shape[0:2]
    ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
    hidden_states = hidden_states + scale * ip_hidden_states
    return hidden_states


class RoPEEmbedding(torch.nn.Module):
    def __init__(self, dim, theta, axes_dim):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim


    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        assert dim % 2 == 0, "The dimension must be even."

        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        omega = 1.0 / (theta**scale)

        batch_size, seq_length = pos.shape
        out = torch.einsum("...n,d->...nd", pos, omega)
        cos_out = torch.cos(out)
        sin_out = torch.sin(out)

        stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
        out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
        return out.float()


    def forward(self, ids):
        n_axes = ids.shape[-1]
        emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
        return emb.unsqueeze(1)



class FluxJointAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.only_out_a = only_out_a

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_b = RMSNorm(head_dim, eps=1e-6)

        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if not only_out_a:
            self.b_to_out = torch.nn.Linear(dim_b, dim_b)


    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

    def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states_a.shape[0]

        # Part A
        qkv_a = self.a_to_qkv(hidden_states_a)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        # Part B
        qkv_b = self.b_to_qkv(hidden_states_b)
        qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
        q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)

        q = torch.concat([q_b, q_a], dim=2)
        k = torch.concat([k_b, k_a], dim=2)
        v = torch.concat([v_b, v_a], dim=2)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
        if ipadapter_kwargs_list is not None:
            hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
        hidden_states_a = self.a_to_out(hidden_states_a)
        if self.only_out_a:
            return hidden_states_a
        else:
            hidden_states_b = self.b_to_out(hidden_states_b)
            return hidden_states_a, hidden_states_b



class FluxJointTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim*4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim*4, dim)
        )


    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b



class FluxSingleAttention(torch.nn.Module):
    def __init__(self, dim_a, dim_b, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)

        self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(head_dim, eps=1e-6)


    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


    def forward(self, hidden_states, image_rotary_emb):
        batch_size = hidden_states.shape[0]

        qkv_a = self.a_to_qkv(hidden_states)
        qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q_a, k_a, v = qkv_a.chunk(3, dim=1)
        q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)

        q, k = self.apply_rope(q_a, k_a, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        return hidden_states



class AdaLayerNormSingle(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)


    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa



class FluxSingleTransformerBlock(torch.nn.Module):
    def __init__(self, dim, num_attention_heads):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = dim // num_attention_heads
        self.dim = dim

        self.norm = AdaLayerNormSingle(dim)
        self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
        self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
        self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)

        self.proj_out = torch.nn.Linear(dim * 5, dim)


    def apply_rope(self, xq, xk, freqs_cis):
        xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
        xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
        xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
        xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
        return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


    def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        batch_size = hidden_states.shape[0]

        qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        q, k = self.norm_q_a(q), self.norm_k_a(k)

        q, k = self.apply_rope(q, k, image_rotary_emb)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)
        if ipadapter_kwargs_list is not None:
            hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
        return hidden_states


    def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
        residual = hidden_states_a
        norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
        hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
        attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]

        attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
        mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")

        hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
        hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
        hidden_states_a = residual + hidden_states_a

        return hidden_states_a, hidden_states_b



class AdaLayerNormContinuous(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.silu = torch.nn.SiLU()
        self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)

    def forward(self, x, conditioning):
        emb = self.linear(self.silu(conditioning))
        shift, scale = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
        return x



class FluxDiT(torch.nn.Module):
    def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
        super().__init__()
        self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
        self.time_embedder = TimestepEmbeddings(256, 3072)
        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
        self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
        self.context_embedder = torch.nn.Linear(4096, 3072)
        self.x_embedder = torch.nn.Linear(input_dim, 3072)

        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
        self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])

        self.final_norm_out = AdaLayerNormContinuous(3072)
        self.final_proj_out = torch.nn.Linear(3072, 64)

        self.input_dim = input_dim


    def patchify(self, hidden_states):
        hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
        return hidden_states


    def unpatchify(self, hidden_states, height, width):
        hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
        return hidden_states


    def prepare_image_ids(self, latents):
        batch_size, _, height, width = latents.shape
        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]

        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

        latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
        latent_image_ids = latent_image_ids.reshape(
            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
        )
        latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)

        return latent_image_ids


    def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
        N = len(entity_masks)
        batch_size = entity_masks[0].shape[0]
        total_seq_len = N * prompt_seq_len + image_seq_len
        patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)

        image_start = N * prompt_seq_len
        image_end = N * prompt_seq_len + image_seq_len
        # prompt-image mask
        for i in range(N):
            prompt_start = i * prompt_seq_len
            prompt_end = (i + 1) * prompt_seq_len
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
            # prompt update with image
            attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
            # image update with prompt
            attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
        # prompt-prompt mask
        for i in range(N):
            for j in range(N):
                if i != j:
                    prompt_start_i = i * prompt_seq_len
                    prompt_end_i = (i + 1) * prompt_seq_len
                    prompt_start_j = j * prompt_seq_len
                    prompt_end_j = (j + 1) * prompt_seq_len
                    attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False

        attention_mask = attention_mask.float()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        return attention_mask


    def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
        max_masks = 0
        attention_mask = None
        prompt_embs = [prompt_emb]
        if entity_masks is not None:
            # entity_masks
            batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
            entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
            entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
            # global mask
            global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
            entity_masks = entity_masks + [global_mask] # append global to last
            # attention mask
            attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
            attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
            attention_mask = attention_mask.unsqueeze(1)
            # embds: n_masks * b * seq * d
            local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
            prompt_embs = local_embs + prompt_embs # append global to last
        prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
        prompt_emb = torch.cat(prompt_embs, dim=1)

        # positional embedding
        text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
        image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
        return prompt_emb, image_rotary_emb, attention_mask


    def forward(
        self,
        hidden_states,
        timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
        tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
        use_gradient_checkpointing=False,
        **kwargs
    ):
        # (Deprecated) The real forward is in `pipelines.flux_image`.
        return None
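A small sketch of the positional-ID layout consumed by `RoPEEmbedding` via `prepare_image_ids`. The latent shape is a made-up assumption (a 128x128 latent, giving a 64x64 patch grid); the construction mirrors the method above but is not the pipeline's full preprocessing.

import torch

latents = torch.randn(1, 16, 128, 128)               # hypothetical (B, C, H, W) latent
batch_size, _, height, width = latents.shape

image_ids = torch.zeros(height // 2, width // 2, 3)   # same recipe as prepare_image_ids
image_ids[..., 1] += torch.arange(height // 2)[:, None]   # row index of each 2x2 patch
image_ids[..., 2] += torch.arange(width // 2)[None, :]    # column index of each 2x2 patch
image_ids = image_ids.reshape(1, -1, 3).repeat(batch_size, 1, 1)

text_ids = torch.zeros(batch_size, 512, 3)            # text tokens all share position (0, 0, 0)
print(image_ids.shape)                                # torch.Size([1, 4096, 3]): one id triple per patch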
diffsynth/models/flux_infiniteyou.py
ADDED
@@ -0,0 +1,129 @@
import math
import torch
import torch.nn as nn


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class InfiniteYouImageProjector(nn.Module):

    def __init__(
        self,
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=8,
        embedding_dim=512,
        output_dim=4096,
        ff_mult=4,
    ):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)
        latents = latents.to(dtype=x.dtype, device=x.device)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)

    @staticmethod
    def state_dict_converter():
        return FluxInfiniteYouImageProjectorStateDictConverter()


class FluxInfiniteYouImageProjectorStateDictConverter:

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict['image_proj']
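A hedged, self-contained sketch of what this resampler consumes and produces. The 512-d input embeddings below are random stand-ins for the identity features the real pipeline supplies; only the tensor shapes are the point.

import torch

projector = InfiniteYouImageProjector()   # defaults: 512-d inputs -> 8 query tokens -> 4096-d outputs
id_embeds = torch.randn(2, 1, 512)        # (batch, num_identity_tokens, embedding_dim), dummy values

with torch.no_grad():
    tokens = projector(id_embeds)

print(tokens.shape)                       # torch.Size([2, 8, 4096]): matches the 4096-d text-embedding stream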
diffsynth/models/flux_ipadapter.py
ADDED
@@ -0,0 +1,110 @@
from .general_modules import RMSNorm
from transformers import SiglipVisionModel, SiglipVisionConfig
import torch


class SiglipVisionModelSO400M(SiglipVisionModel):
    def __init__(self):
        config = SiglipVisionConfig(
            hidden_size=1152,
            image_size=384,
            intermediate_size=4304,
            model_type="siglip_vision_model",
            num_attention_heads=16,
            num_hidden_layers=27,
            patch_size=14,
            architectures=["SiglipModel"],
            initializer_factor=1.0,
            torch_dtype="float32",
            transformers_version="4.37.0.dev0"
        )
        super().__init__(config)

class MLPProjModel(torch.nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x

class IpAdapterModule(torch.nn.Module):
    def __init__(self, num_attention_heads, attention_head_dim, input_dim):
        super().__init__()
        self.num_heads = num_attention_heads
        self.head_dim = attention_head_dim
        output_dim = num_attention_heads * attention_head_dim
        self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
        self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)


    def forward(self, hidden_states):
        batch_size = hidden_states.shape[0]
        # ip_k
        ip_k = self.to_k_ip(hidden_states)
        ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        ip_k = self.norm_added_k(ip_k)
        # ip_v
        ip_v = self.to_v_ip(hidden_states)
        ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        return ip_k, ip_v


class FluxIpAdapter(torch.nn.Module):
    def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
        super().__init__()
        self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
        self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
        self.set_adapter()

    def set_adapter(self):
        self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}

    def forward(self, hidden_states, scale=1.0):
        hidden_states = self.image_proj(hidden_states)
        hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
        ip_kv_dict = {}
        for block_id in self.call_block_id:
            ipadapter_id = self.call_block_id[block_id]
            ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
            ip_kv_dict[block_id] = {
                "ip_k": ip_k,
                "ip_v": ip_v,
                "scale": scale
            }
        return ip_kv_dict

    @staticmethod
    def state_dict_converter():
        return FluxIpAdapterStateDictConverter()


class FluxIpAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {}
        for name in state_dict["ip_adapter"]:
            name_ = 'ipadapter_modules.' + name
            state_dict_[name_] = state_dict["ip_adapter"][name]
        for name in state_dict["image_proj"]:
            name_ = "image_proj." + name
            state_dict_[name_] = state_dict["image_proj"][name]
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)
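A hedged sketch of the adapter's input/output contract. The image embedding is a random placeholder for a pooled SigLIP-SO400M feature, and a randomly initialized FluxIpAdapter needs several GB of CPU RAM; the shapes, not the values, are what the sketch illustrates.

import torch

ip_adapter = FluxIpAdapter()              # 57 per-block K/V projectors by default
image_embeds = torch.randn(1, 1152)       # stand-in for a pooled 1152-d SigLIP image embedding

with torch.no_grad():
    ip_kv = ip_adapter(image_embeds, scale=0.7)

print(len(ip_kv))                         # 57: one {"ip_k", "ip_v", "scale"} entry per DiT attention block
print(ip_kv[0]["ip_k"].shape)             # torch.Size([1, 24, 128, 128]) -> (B, heads, num_tokens, head_dim)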