PencilHu committed
Commit 1146a67 · verified · 1 Parent(s): b0369f9

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +115 -0
  2. .github/workflows/logo.gif +3 -0
  3. .github/workflows/publish.yaml +29 -0
  4. .gitignore +175 -0
  5. .home/.modelscope/credentials/session +1 -0
  6. .swp +0 -0
  7. LICENSE +201 -0
  8. README.md +783 -0
  9. README_zh.md +784 -0
  10. assets/egg.mp4 +3 -0
  11. comp_attn_bbox_layout.png +0 -0
  12. comp_attn_trajectory.png +0 -0
  13. diffsynth/__init__.py +1 -0
  14. diffsynth/configs/__init__.py +2 -0
  15. diffsynth/configs/model_configs.py +518 -0
  16. diffsynth/configs/vram_management_module_maps.py +197 -0
  17. diffsynth/core/__init__.py +5 -0
  18. diffsynth/core/attention/__init__.py +1 -0
  19. diffsynth/core/attention/attention.py +121 -0
  20. diffsynth/core/data/__init__.py +1 -0
  21. diffsynth/core/data/operators.py +218 -0
  22. diffsynth/core/data/unified_dataset.py +112 -0
  23. diffsynth/core/gradient/__init__.py +1 -0
  24. diffsynth/core/gradient/gradient_checkpoint.py +34 -0
  25. diffsynth/core/loader/__init__.py +3 -0
  26. diffsynth/core/loader/config.py +117 -0
  27. diffsynth/core/loader/file.py +121 -0
  28. diffsynth/core/loader/model.py +79 -0
  29. diffsynth/core/vram/__init__.py +2 -0
  30. diffsynth/core/vram/disk_map.py +93 -0
  31. diffsynth/core/vram/initialization.py +21 -0
  32. diffsynth/core/vram/layers.py +475 -0
  33. diffsynth/datasets/mvdataset.py +393 -0
  34. diffsynth/diffusion/__init__.py +6 -0
  35. diffsynth/diffusion/base_pipeline.py +439 -0
  36. diffsynth/diffusion/flow_match.py +179 -0
  37. diffsynth/diffusion/logger.py +43 -0
  38. diffsynth/diffusion/loss.py +119 -0
  39. diffsynth/diffusion/parsers.py +70 -0
  40. diffsynth/diffusion/runner.py +129 -0
  41. diffsynth/diffusion/training_module.py +212 -0
  42. diffsynth/models/comp_attn_model.py +592 -0
  43. diffsynth/models/dinov3_image_encoder.py +94 -0
  44. diffsynth/models/flux2_dit.py +1057 -0
  45. diffsynth/models/flux2_text_encoder.py +58 -0
  46. diffsynth/models/flux2_vae.py +0 -0
  47. diffsynth/models/flux_controlnet.py +384 -0
  48. diffsynth/models/flux_dit.py +395 -0
  49. diffsynth/models/flux_infiniteyou.py +129 -0
  50. diffsynth/models/flux_ipadapter.py +110 -0
.gitattributes CHANGED
@@ -33,3 +33,118 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ .github/workflows/logo.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/egg.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ examples/Comp-Attn.pdf filter=lfs diff=lfs merge=lfs -text
39
+ examples/InstanceV.pdf filter=lfs diff=lfs merge=lfs -text
40
+ examples/wanvideo/model_training/egg_statemachine_dataset/egg_4fps_640x320.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_1280x720.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_640x320.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ examples/wanvideo/model_training/egg_statemachine_dataset/egg_5fps_NonexNone.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ examples/wanvideo/model_training/egg_statemachine_dataset/egg_8fps_448x256.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ output/_smoke_352x640/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ output/_smoke_352x640/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ output/_smoke_352x640/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ output/_smoke_352x640/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ output/_smoke_352x640/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
50
+ output/wan2.1-1.3b-mc-lora/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-inside.mp4 filter=lfs diff=lfs merge=lfs -text
51
+ output/wan2.1-1.3b-mc-lora/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ output/wan2.1-1.3b-mc-lora/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-fron.mp4 filter=lfs diff=lfs merge=lfs -text
53
+ output/wan2.1-1.3b-mc-lora/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-next-to.mp4 filter=lfs diff=lfs merge=lfs -text
54
+ output/wan2.1-1.3b-mc-lora/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-inside.mp4 filter=lfs diff=lfs merge=lfs -text
55
+ output/wan2.1-1.3b-mc-lora/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-a-ca.mp4 filter=lfs diff=lfs merge=lfs -text
56
+ output/wan2.1-1.3b-mc-lora/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
57
+ output/wan2.1-1.3b-mc-lora/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-at-the.mp4 filter=lfs diff=lfs merge=lfs -text
58
+ output/wan2.1-1.3b-mc-lora/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-sm.mp4 filter=lfs diff=lfs merge=lfs -text
59
+ output/wan2.1-1.3b-mc-lora/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
60
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text
61
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text
62
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
63
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text
64
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text
65
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
66
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
67
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
68
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
69
+ output/wan2.1-1.3b-mc-lora_352x640_stable/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
70
+ output/wan2.1-1.3b-mc-lora_batch10/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text
71
+ output/wan2.1-1.3b-mc-lora_batch10/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text
72
+ output/wan2.1-1.3b-mc-lora_batch10/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
73
+ output/wan2.1-1.3b-mc-lora_batch10/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text
74
+ output/wan2.1-1.3b-mc-lora_batch10/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text
75
+ output/wan2.1-1.3b-mc-lora_batch10/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
76
+ output/wan2.1-1.3b-mc-lora_batch10/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
77
+ output/wan2.1-1.3b-mc-lora_batch10/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
78
+ output/wan2.1-1.3b-mc-lora_batch10/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
79
+ output/wan2.1-1.3b-mc-lora_batch10/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
80
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p01_seed0_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-charact.mp4 filter=lfs diff=lfs merge=lfs -text
81
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p02_seed1_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-where-t.mp4 filter=lfs diff=lfs merge=lfs -text
82
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p03_seed2_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
83
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p04_seed3_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-set-in.mp4 filter=lfs diff=lfs merge=lfs -text
84
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p05_seed4_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-during.mp4 filter=lfs diff=lfs merge=lfs -text
85
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p06_seed5_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
86
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p07_seed6_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-on-a-mo.mp4 filter=lfs diff=lfs merge=lfs -text
87
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p08_seed7_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-in-a-vi.mp4 filter=lfs diff=lfs merge=lfs -text
88
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p09_seed8_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-as-the.mp4 filter=lfs diff=lfs merge=lfs -text
89
+ output/wan2.1-1.3b-mc-lora_epoch4_1/p10_seed9_the-video-showcases-a-first-person-perspective-within-the-game-minecraft-near-a.mp4 filter=lfs diff=lfs merge=lfs -text
90
+ output/wan2.1-1.3b-statemachine-egg/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
91
+ output/wan2.1-1.3b-statemachine-egg_cooked2raw/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
92
+ output/wan2.1-1.3b-statemachine-egg_moveup_long/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
93
+ output/wan2.1-1.3b-statemachine-egg_moveup_long20/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
94
+ output/wan2.1-1.3b-statemachine-egg_moveup_long20_promptclean/egg_statemachine_infer.mp4 filter=lfs diff=lfs merge=lfs -text
95
+ outputs/instancev/boat_seagull_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
96
+ outputs/instancev/deer_approach_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
97
+ outputs/instancev/dog_running.mp4 filter=lfs diff=lfs merge=lfs -text
98
+ outputs/instancev/dog_running_baseline.mp4 filter=lfs diff=lfs merge=lfs -text
99
+ outputs/instancev/four_people_talking.mp4 filter=lfs diff=lfs merge=lfs -text
100
+ outputs/instancev/four_pigeons_orbit_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
101
+ outputs/instancev/multi_instances_animals.mp4 filter=lfs diff=lfs merge=lfs -text
102
+ outputs/instancev/single_car_sweep.mp4 filter=lfs diff=lfs merge=lfs -text
103
+ outputs/instancev/three_diagonal_motion.mp4 filter=lfs diff=lfs merge=lfs -text
104
+ outputs/instancev/two_crossing_athletes.mp4 filter=lfs diff=lfs merge=lfs -text
105
+ outputs/instancev/two_people_talking.mp4 filter=lfs diff=lfs merge=lfs -text
106
+ outputs/instancev/two_scooters_crossing_20260105_112906.mp4 filter=lfs diff=lfs merge=lfs -text
107
+ outputs/instancev/two_scooters_crossing_20260105_113836.mp4 filter=lfs diff=lfs merge=lfs -text
108
+ outputs/instancev/two_scooters_crossing_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
109
+ outputs/instancev/two_students_drone_20260105_114652.mp4 filter=lfs diff=lfs merge=lfs -text
110
+ outputs/instancev-new/case_00.mp4 filter=lfs diff=lfs merge=lfs -text
111
+ outputs/instancev-new/case_01.mp4 filter=lfs diff=lfs merge=lfs -text
112
+ outputs/instancev-new/case_02.mp4 filter=lfs diff=lfs merge=lfs -text
113
+ outputs/instancev-new/case_03.mp4 filter=lfs diff=lfs merge=lfs -text
114
+ outputs/instancev-new/case_04.mp4 filter=lfs diff=lfs merge=lfs -text
115
+ outputs/instancev-new/case_05.mp4 filter=lfs diff=lfs merge=lfs -text
116
+ outputs/instancev-new/case_06.mp4 filter=lfs diff=lfs merge=lfs -text
117
+ outputs/instancev-new/case_07.mp4 filter=lfs diff=lfs merge=lfs -text
118
+ outputs/instancev-new/case_08.mp4 filter=lfs diff=lfs merge=lfs -text
119
+ outputs/instancev-new/case_09.mp4 filter=lfs diff=lfs merge=lfs -text
120
+ outputs/instancev-new/case_10.mp4 filter=lfs diff=lfs merge=lfs -text
121
+ outputs/instancev-new/case_11.mp4 filter=lfs diff=lfs merge=lfs -text
122
+ outputs/instancev-new/case_12.mp4 filter=lfs diff=lfs merge=lfs -text
123
+ outputs/instancev-new/case_13.mp4 filter=lfs diff=lfs merge=lfs -text
124
+ outputs/instancev-new/case_14.mp4 filter=lfs diff=lfs merge=lfs -text
125
+ outputs/instancev-new/case_15.mp4 filter=lfs diff=lfs merge=lfs -text
126
+ outputs/instancev-new/case_16.mp4 filter=lfs diff=lfs merge=lfs -text
127
+ outputs/instancev-new/case_17.mp4 filter=lfs diff=lfs merge=lfs -text
128
+ outputs/instancev-new/case_18.mp4 filter=lfs diff=lfs merge=lfs -text
129
+ outputs/instancev-new/case_19.mp4 filter=lfs diff=lfs merge=lfs -text
130
+ outputs/instancev_iground_infer_20260107_024437.mp4 filter=lfs diff=lfs merge=lfs -text
131
+ outputs/instancev_iground_test.mp4 filter=lfs diff=lfs merge=lfs -text
132
+ video.mp4 filter=lfs diff=lfs merge=lfs -text
133
+ video_1_Wan2.1-T2V-1.3B-LoRA2.mp4 filter=lfs diff=lfs merge=lfs -text
134
+ video_1_Wan2.1-T2V-1.3B-LoRA_epoch1.mp4 filter=lfs diff=lfs merge=lfs -text
135
+ video_1_Wan2.1-T2V-1.3B.mp4 filter=lfs diff=lfs merge=lfs -text
136
+ video_1_Wan2.1-T2V-1.3B_LoRA.mp4 filter=lfs diff=lfs merge=lfs -text
137
+ video_comp_attn_pipeline[[:space:]]copy.mp4 filter=lfs diff=lfs merge=lfs -text
138
+ video_comp_attn_pipeline.mp4 filter=lfs diff=lfs merge=lfs -text
139
+ wandb/run-20251211_101851-syoqkmhy/run-syoqkmhy.wandb filter=lfs diff=lfs merge=lfs -text
140
+ wandb/run-20251211_172331-jxaicuod/run-jxaicuod.wandb filter=lfs diff=lfs merge=lfs -text
141
+ wandb/run-20251225_172459-gjtz0um5/run-gjtz0um5.wandb filter=lfs diff=lfs merge=lfs -text
142
+ wandb/run-20251225_214534-3dh8lbav/run-3dh8lbav.wandb filter=lfs diff=lfs merge=lfs -text
143
+ wandb/run-20251229_100816-zirij84a/run-zirij84a.wandb filter=lfs diff=lfs merge=lfs -text
144
+ wandb/run-20260102_054910-38oaloji/run-38oaloji.wandb filter=lfs diff=lfs merge=lfs -text
145
+ wandb/run-20260102_104929-zd02vtce/run-zd02vtce.wandb filter=lfs diff=lfs merge=lfs -text
146
+ wandb/run-20260102_162705-mr7vgtqn/run-mr7vgtqn.wandb filter=lfs diff=lfs merge=lfs -text
147
+ wandb/run-20260103_090415-36yjbun5/run-36yjbun5.wandb filter=lfs diff=lfs merge=lfs -text
148
+ wandb/run-20260103_115016-kurow4tk/run-kurow4tk.wandb filter=lfs diff=lfs merge=lfs -text
149
+ wandb/run-20260106_030539-rupbhtts/run-rupbhtts.wandb filter=lfs diff=lfs merge=lfs -text
150
+ wandb/run-20260110_110203-bl4gd6wi/run-bl4gd6wi.wandb filter=lfs diff=lfs merge=lfs -text
.github/workflows/logo.gif ADDED

Git LFS Details

  • SHA256: 36a7627b7f0f0a508ec64aba72e5d95d38dfe7958bd8cf42d2a63f6ac2641529
  • Pointer size: 131 Bytes
  • Size of remote file: 149 kB
.github/workflows/publish.yaml ADDED
@@ -0,0 +1,29 @@
1
+ name: release
2
+
3
+ on:
4
+   push:
5
+     tags:
6
+       - 'v**'
7
+
8
+ concurrency:
9
+   group: ${{ github.workflow }}-${{ github.ref }}-publish
10
+   cancel-in-progress: true
11
+
12
+ jobs:
13
+   build-n-publish:
14
+     runs-on: ubuntu-20.04
15
+     #if: startsWith(github.event.ref, 'refs/tags')
16
+     steps:
17
+     - uses: actions/checkout@v2
18
+     - name: Set up Python 3.10
19
+       uses: actions/setup-python@v2
20
+       with:
21
+         python-version: '3.10'
22
+     - name: Install wheel
23
+       run: pip install wheel==0.44.0 && pip install -r requirements.txt
24
+     - name: Build DiffSynth
25
+       run: python setup.py sdist bdist_wheel
26
+     - name: Publish package to PyPI
27
+       run: |
28
+         pip install twine
29
+         twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,175 @@
1
+ /data
2
+ /models
3
+ /scripts
4
+ /diffusers
5
+ *.pkl
6
+ *.safetensors
7
+ *.pth
8
+ *.ckpt
9
+ *.pt
10
+ *.bin
11
+ *.DS_Store
12
+ *.msc
13
+ *.mv
14
+ log*.txt
15
+
16
+ # Byte-compiled / optimized / DLL files
17
+ __pycache__/
18
+ *.py[cod]
19
+ *$py.class
20
+
21
+ # C extensions
22
+ *.so
23
+
24
+ # Distribution / packaging
25
+ .Python
26
+ build/
27
+ develop-eggs/
28
+ dist/
29
+ downloads/
30
+ eggs/
31
+ .eggs/
32
+ lib/
33
+ lib64/
34
+ parts/
35
+ sdist/
36
+ var/
37
+ wheels/
38
+ share/python-wheels/
39
+ *.egg-info/
40
+ .installed.cfg
41
+ *.egg
42
+ MANIFEST
43
+
44
+ # PyInstaller
45
+ # Usually these files are written by a python script from a template
46
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
47
+ *.manifest
48
+ *.spec
49
+
50
+ # Installer logs
51
+ pip-log.txt
52
+ pip-delete-this-directory.txt
53
+
54
+ # Unit test / coverage reports
55
+ htmlcov/
56
+ .tox/
57
+ .nox/
58
+ .coverage
59
+ .coverage.*
60
+ .cache
61
+ nosetests.xml
62
+ coverage.xml
63
+ *.cover
64
+ *.py,cover
65
+ .hypothesis/
66
+ .pytest_cache/
67
+ cover/
68
+
69
+ # Translations
70
+ *.mo
71
+ *.pot
72
+
73
+ # Django stuff:
74
+ *.log
75
+ local_settings.py
76
+ db.sqlite3
77
+ db.sqlite3-journal
78
+
79
+ # Flask stuff:
80
+ instance/
81
+ .webassets-cache
82
+
83
+ # Scrapy stuff:
84
+ .scrapy
85
+
86
+ # Sphinx documentation
87
+ docs/_build/
88
+
89
+ # PyBuilder
90
+ .pybuilder/
91
+ target/
92
+
93
+ # Jupyter Notebook
94
+ .ipynb_checkpoints
95
+
96
+ # IPython
97
+ profile_default/
98
+ ipython_config.py
99
+
100
+ # pyenv
101
+ # For a library or package, you might want to ignore these files since the code is
102
+ # intended to run in multiple environments; otherwise, check them in:
103
+ # .python-version
104
+
105
+ # pipenv
106
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
108
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
109
+ # install all needed dependencies.
110
+ #Pipfile.lock
111
+
112
+ # poetry
113
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
114
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
115
+ # commonly ignored for libraries.
116
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
117
+ #poetry.lock
118
+
119
+ # pdm
120
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
121
+ #pdm.lock
122
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
123
+ # in version control.
124
+ # https://pdm.fming.dev/#use-with-ide
125
+ .pdm.toml
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .venv
140
+ env/
141
+ venv/
142
+ ENV/
143
+ env.bak/
144
+ venv.bak/
145
+
146
+ # Spyder project settings
147
+ .spyderproject
148
+ .spyproject
149
+
150
+ # Rope project settings
151
+ .ropeproject
152
+
153
+ # mkdocs documentation
154
+ /site
155
+
156
+ # mypy
157
+ .mypy_cache/
158
+ .dmypy.json
159
+ dmypy.json
160
+
161
+ # Pyre type checker
162
+ .pyre/
163
+
164
+ # pytype static type analyzer
165
+ .pytype/
166
+
167
+ # Cython debug symbols
168
+ cython_debug/
169
+
170
+ # PyCharm
171
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
172
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
173
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
174
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
175
+ #.idea/
.home/.modelscope/credentials/session ADDED
@@ -0,0 +1 @@
1
+ 13921be3c1924b38a5a21db02dce6b94
.swp ADDED
Binary file (12.3 kB).
 
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2023] [Zhongjie Duan]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,783 @@
1
+ # DiffSynth-Studio
2
+
3
+ <a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
4
+
5
+ [![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
6
+ [![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
7
+ [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
8
+ [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
9
+ [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
10
+
11
+ [Switch to the Chinese version](./README_zh.md)
12
+
13
+ ## Introduction
14
+
15
+ Welcome to the magical world of Diffusion models! DiffSynth-Studio is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We hope to foster technological innovation through framework development, bring together the strength of the open-source community, and explore the boundaries of generative model technology!
16
+
17
+ DiffSynth currently includes two open-source projects:
18
+ * [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): Focused on aggressive technical exploration, targeting academia, and providing cutting-edge model capability support.
19
+ * [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): Focused on stable model deployment, targeting industry, and providing higher computational performance and more stable features.
20
+
21
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) and [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) are the core engines of the ModelScope AIGC zone. You are welcome to try our carefully crafted product features:
22
+
23
+ * ModelScope AIGC Zone (for Chinese users): https://modelscope.cn/aigc/home
24
+ * ModelScope Civision (for global users): https://modelscope.ai/civision/home
25
+
26
+ > DiffSynth-Studio Documentation: [Chinese version](/docs/zh/README.md), [English version](/docs/en/README.md)
27
+
28
+ We believe that a well-built open-source code framework lowers the barrier to technical exploration. We have built many [interesting technologies](#innovative-achievements) on top of this codebase, and perhaps you have wild ideas of your own that DiffSynth-Studio can help you realize quickly. For this reason, we have prepared detailed documentation for developers. We hope these documents help you understand the principles of Diffusion models, and we look forward to expanding the boundaries of this technology together with you.
29
+
30
+ ## Update History
31
+
32
+ > DiffSynth-Studio has undergone major version updates, and some old features are no longer maintained. If you need to use old features, please switch to the [last historical version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3) before the major version update.
33
+
34
+ > This project currently has a small development team, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). As a result, new features will be developed relatively slowly, and our capacity to respond to and resolve issues is limited. We apologize for this and ask for your understanding.
35
+
36
+ - **December 9, 2025** We released a wild model built on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in generalization, detail preservation, and other aspects, we are open-sourcing it to inspire more innovative research.
37
+
38
+ - **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
39
+ - [Documentation](/docs/en/README.md) is online: it is still being continuously improved and updated
40
+ - [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md) module upgraded: it now supports layer-level disk offload, freeing both system memory and VRAM at the same time
41
+ - New model support
42
+ - Z-Image Turbo: [Model](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo), [Documentation](/docs/en/Model_Details/Z-Image.md), [Code](/examples/z_image/)
43
+ - FLUX.2-dev: [Model](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev), [Documentation](/docs/en/Model_Details/FLUX2.md), [Code](/examples/flux2/)
44
+ - Training framework upgrade
45
+ - [Split Training](/docs/zh/Training/Split_Training.md): Automatically splits the training process into two stages, data processing and training (even when training ControlNet or any other model). Computations that do not require gradient backpropagation, such as text encoding and VAE encoding, are performed in the data processing stage, while the remaining computations run in the training stage. This yields faster training and lower VRAM requirements; a generic sketch of the idea follows this list.
46
+ - [Differential LoRA Training](/docs/zh/Training/Differential_LoRA.md): This is a training technique we used in [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), now available for LoRA training of any model.
47
+ - [FP8 Training](/docs/zh/Training/FP8_Precision.md): FP8 can be applied to any non-trained model during training, i.e., models whose gradients are turned off or whose gradients only affect LoRA weights.
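
To make the split-training idea above concrete, here is a generic, framework-agnostic sketch of the two stages. This is not DiffSynth-Studio's actual implementation: the `text_encoder`, `vae`, `dit`, and `loss_fn` arguments are hypothetical placeholders for whatever frozen encoders and trainable denoiser a given pipeline uses.

```python
import torch

@torch.no_grad()
def preprocess(dataset, text_encoder, vae, cache_path):
    # Stage 1 (data processing): run everything that needs no gradients
    # (text encoding, VAE encoding) once and cache the results to disk.
    cache = []
    for sample in dataset:
        cache.append({
            "text_emb": text_encoder(sample["prompt"]).cpu(),  # hypothetical encoder call
            "latents": vae.encode(sample["video"]).cpu(),      # hypothetical VAE call
        })
    torch.save(cache, cache_path)

def train(dit, loss_fn, optimizer, cache_path, num_epochs=1, device="cuda"):
    # Stage 2 (training): only the denoising model is loaded and updated,
    # so the text encoder and VAE no longer consume VRAM or compute time.
    cache = torch.load(cache_path)
    for _ in range(num_epochs):
        for item in cache:
            loss = loss_fn(dit, item["latents"].to(device), item["text_emb"].to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
```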
48
+
49
+ <details>
50
+ <summary>More</summary>
51
+
52
+ - **November 4, 2025** Supported the [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) model, which is trained based on Wan 2.1 and supports generating corresponding actions based on reference videos.
53
+
54
+ - **October 30, 2025** Supported the [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) model, which supports text-to-video, image-to-video, and video continuation. This model uses the Wan framework for inference and training in this project.
55
+
56
+ - **October 27, 2025** Supported the [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) model, adding another member to the Wan model ecosystem.
57
+
58
+ - **September 23, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) released! This model was jointly developed and open-sourced by us and Taobao Experience Design Team. Built upon Qwen-Image, the model is specifically designed for e-commerce poster scenarios, supporting precise partition layout control. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py).
59
+
60
+ - **September 9, 2025** Our training framework supports multiple training modes. It is currently adapted for Qwen-Image: in addition to the standard SFT training mode, Direct Distill is now supported. Please refer to [our sample code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh). This feature is experimental, and we will continue to improve it to support more comprehensive model training functionality.
61
+
62
+ - **August 28, 2025** We support Wan2.2-S2V, an audio-driven cinematic video generation model. See [./examples/wanvideo/](./examples/wanvideo/).
63
+
64
+ - **August 21, 2025** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) released! Compared to the V1 version, the training dataset has been changed to [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset), so the generated images better conform to Qwen-Image's own image distribution and style. Please refer to [our sample code](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py).
65
+
66
+ - **August 21, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) structural control LoRA model, adopting the In Context technical route, supporting multiple categories of structural control conditions, including canny, depth, lineart, softedge, normal, and openpose. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py).
67
+
68
+ - **August 20, 2025** We open-sourced the [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) model, improving the editing effect of Qwen-Image-Edit on low-resolution image inputs. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
69
+
70
+ - **August 19, 2025** 🔥 Qwen-Image-Edit has been open-sourced, welcoming a new member to the image editing model family!
71
+
72
+ - **August 18, 2025** We trained and open-sourced the Qwen-Image inpainting ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py).
73
+
74
+ - **August 15, 2025** We open-sourced the [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) dataset. This is an image dataset generated using the Qwen-Image model, containing 160,000 `1024 x 1024` images. It includes general, English text rendering, and Chinese text rendering subsets. We provide annotations for image descriptions, entities, and structural control images for each image. Developers can use this dataset to train Qwen-Image models' ControlNet and EliGen models. We aim to promote technological development through open-sourcing!
75
+
76
+ - **August 13, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py).
77
+
78
+ - **August 12, 2025** We trained and open-sourced the Qwen-Image ControlNet model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny). The model structure adopts a lightweight design. Please refer to [our sample code](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py).
79
+
80
+ - **August 11, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) for Qwen-Image, following the same training process as [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), but the model structure has been modified to LoRA, thus being better compatible with other open-source ecosystem models.
81
+
82
+ - **August 7, 2025** We open-sourced the entity control LoRA model [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) for Qwen-Image. Qwen-Image-EliGen can achieve entity-level controlled text-to-image generation. Technical details can be found in [the paper](https://arxiv.org/abs/2501.01097). Training dataset: [EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet).
83
+
84
+ - **August 5, 2025** We open-sourced the distilled acceleration model [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) for Qwen-Image, achieving approximately 5x acceleration.
85
+
86
+ - **August 4, 2025** 🔥 Qwen-Image has been open-sourced, welcoming a new member to the image generation model family!
87
+
88
+ - **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) open-sourced, a text-to-image model focused on aesthetic photography. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, LoRA training, and full training. For more details, please refer to [./examples/flux/](./examples/flux/).
89
+
90
+ - **July 28, 2025** Wan 2.2 open-sourced. We provided comprehensive support in a timely manner, including low VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. For more details, please refer to [./examples/wanvideo/](./examples/wanvideo/).
91
+
92
+ - **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of Large Language Models (LLMs) with the image generation capabilities of diffusion models. This framework supports seamless image understanding, generation, and editing tasks.
93
+ - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
94
+ - GitHub Repository: https://github.com/modelscope/Nexus-Gen
95
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
96
+ - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
97
+ - Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
98
+
99
+ - **June 15, 2025** ModelScope's official evaluation framework [EvalScope](https://github.com/modelscope/evalscope) now supports text-to-image generation evaluation. Please refer to the [best practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide to try it out.
100
+
101
+ - **March 31, 2025** We support InfiniteYou, a face feature preservation method for FLUX. More details can be found in [./examples/InfiniteYou/](./examples/InfiniteYou/).
102
+
103
+ - **March 25, 2025** Our new project [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) is now open-sourced! It focuses on stable model deployment for industry, providing better engineering support, higher computational performance, and more stable features.
104
+
105
+ - **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of Tencent's open-source HunyuanVideo. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
106
+
107
+ - **February 25, 2025** We support Wan-Video, a series of state-of-the-art video synthesis models open-sourced by Alibaba. See [./examples/wanvideo/](./examples/wanvideo/).
108
+
109
+ - **February 17, 2025** We support [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)! Advanced video synthesis model! See [./examples/stepvideo](./examples/stepvideo/).
110
+
111
+ - **December 31, 2024** We propose EliGen, a new framework for entity-level controlled text-to-image generation, supplemented with an inpainting fusion pipeline, extending its capabilities to image inpainting tasks. EliGen can seamlessly integrate existing community models such as IP-Adapter and In-Context LoRA, enhancing their versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/).
112
+ - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
113
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
114
+ - Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
115
+ - Training Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
116
+
117
+ - **December 19, 2024** We implemented advanced VRAM management for HunyuanVideo, enabling video generation with resolutions of 129x720x1280 on 24GB VRAM or 129x512x384 on just 6GB VRAM. More details can be found in [./examples/HunyuanVideo/](./examples/HunyuanVideo/).
118
+
119
+ - **December 18, 2024** We propose ArtAug, a method to improve text-to-image models through synthesis-understanding interaction. We trained an ArtAug enhancement module for FLUX.1-dev in LoRA format. This model incorporates the aesthetic understanding of Qwen2-VL-72B into FLUX.1-dev, thereby improving the quality of generated images.
120
+ - Paper: https://arxiv.org/abs/2412.12888
121
+ - Example: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
122
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
123
+ - Demo: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (coming soon)
124
+
125
+ - **October 25, 2024** We provide extensive FLUX ControlNet support. This project supports many different ControlNet models, which can be freely combined even if their structures differ. Additionally, ControlNet models are compatible with high-resolution optimization and partition control technologies, enabling very powerful controllable image generation. See [`./examples/ControlNet/`](./examples/ControlNet/).
126
+
127
+ - **October 8, 2024** We released extended LoRAs based on CogVideoX-5B and ExVideo. You can download this model from [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) or [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1).
128
+
129
+ - **August 22, 2024** This project now supports CogVideoX-5B. See [here](/examples/video_synthesis/). We provide several interesting features for this text-to-video model, including:
130
+ - Text-to-video
131
+ - Video editing
132
+ - Self super-resolution
133
+ - Video interpolation
134
+
135
+ - **August 22, 2024** We implemented an interesting brush feature that supports all text-to-image models. Now you can create stunning images with the assistance of AI using the brush!
136
+ - Use it in our [WebUI](#usage-in-webui).
137
+
138
+ - **August 21, 2024** DiffSynth-Studio now supports FLUX.
139
+ - Enable CFG and high-resolution inpainting to improve visual quality. See [here](/examples/image_synthesis/README.md)
140
+ - LoRA, ControlNet, and other addon models will be released soon.
141
+
142
+ - **June 21, 2024** We propose ExVideo, a post-training fine-tuning technique aimed at enhancing the capabilities of video generation models. We extended Stable Video Diffusion to achieve long video generation of up to 128 frames.
143
+ - [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
144
+ - Source code has been released in this repository. See [`examples/ExVideo`](./examples/ExVideo/).
145
+ - Model has been released at [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) and [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1).
146
+ - Technical report has been released at [arXiv](https://arxiv.org/abs/2406.14130).
147
+ - You can try ExVideo in this [demo](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1)!
148
+
149
+ - **June 13, 2024** DiffSynth Studio has migrated to ModelScope. The development team has also transitioned from "me" to "us". Of course, I will still participate in subsequent development and maintenance work.
150
+
151
+ - **January 29, 2024** We propose Diffutoon, an excellent cartoon coloring solution.
152
+ - [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
153
+ - Source code has been released in this project.
154
+ - Technical report (IJCAI 2024) has been released at [arXiv](https://arxiv.org/abs/2401.16224).
155
+
156
+ - **December 8, 2023** We decided to initiate a new project aimed at unleashing the potential of diffusion models, especially in video synthesis. The development work of this project officially began.
157
+
158
+ - **November 15, 2023** We propose FastBlend, a powerful video deflickering algorithm.
159
+ - sd-webui extension has been released at [GitHub](https://github.com/Artiprocher/sd-webui-fastblend).
160
+ - Demonstration videos have been showcased on Bilibili, including three tasks:
161
+ - [Video Deflickering](https://www.bilibili.com/video/BV1d94y1W7PE)
162
+ - [Video Interpolation](https://www.bilibili.com/video/BV1Lw411m71p)
163
+ - [Image-Driven Video Rendering](https://www.bilibili.com/video/BV1RB4y1Z7LF)
164
+ - Technical report has been released at [arXiv](https://arxiv.org/abs/2311.09265).
165
+ - Unofficial ComfyUI extensions developed by other users have been released at [GitHub](https://github.com/AInseven/ComfyUI-fastblend).
166
+
167
+ - **October 1, 2023** We released an early version of the project named FastSDXL. This was an initial attempt to build a diffusion engine.
168
+ - Source code has been released at [GitHub](https://github.com/Artiprocher/FastSDXL).
169
+ - FastSDXL includes a trainable OLSS scheduler to improve efficiency.
170
+ - The original repository of OLSS is located [here](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler).
171
+ - Technical report (CIKM 2023) has been released at [arXiv](https://arxiv.org/abs/2305.14677).
172
+ - Demonstration video has been released at [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj).
173
+ - Since OLSS requires additional training, we did not implement it in this project.
174
+
175
+ - **August 29, 2023** We propose DiffSynth, a video synthesis framework.
176
+ - [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/).
177
+ - Source code has been released at [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth).
178
+ - Technical report (ECML PKDD 2024) has been released at [arXiv](https://arxiv.org/abs/2308.03463).
179
+
180
+ </details>
181
+
182
+ ## Installation
183
+
184
+ Install from source (recommended):
185
+
186
+ ```shell
187
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
188
+ cd DiffSynth-Studio
189
+ pip install -e .
190
+ ```
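+
+ After installation, the following minimal snippet is an optional sanity check (not an official step) to confirm that the package can be imported:
+
+ ```python
+ # Optional sanity check: verify that the freshly installed package is importable.
+ import diffsynth
+ print("DiffSynth-Studio imported from:", diffsynth.__file__)
+ ```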
191
+
192
+ <details>
193
+ <summary>Other installation methods</summary>
194
+
195
+ Install from PyPI (version updates may be delayed; for the latest features, install from source):
196
+
197
+ ```shell
198
+ pip install diffsynth
199
+ ```
200
+
201
+ If you encounter problems during installation, they might be caused by upstream dependencies. Please check the documentation of these packages:
202
+
203
+ * [torch](https://pytorch.org/get-started/locally/)
204
+ * [sentencepiece](https://github.com/google/sentencepiece)
205
+ * [cmake](https://cmake.org)
206
+ * [cupy](https://docs.cupy.dev/en/stable/install.html)
207
+
208
+ </details>
209
+
210
+ ## Basic Framework
211
+
212
+ DiffSynth-Studio redesigns the inference and training pipelines for mainstream diffusion models (including FLUX, Wan, etc.), enabling efficient memory management and flexible model training.
213
+
214
+ <details>
215
+ <summary>Environment Variable Configuration</summary>
216
+
217
+ > Before running model inference or training, you can configure settings such as the model download source via [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md).
218
+ >
219
+ > By default, this project downloads models from ModelScope. For users outside China, you can configure the system to download models from the ModelScope international site as follows:
220
+ >
221
+ > ```python
222
+ > import os
223
+ > os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
224
+ > ```
225
+ >
226
+ > To download models from other sources, please modify the environment variable [DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source).
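+ >
+ > For example, a minimal sketch (the value below is purely illustrative; consult the linked documentation for the accepted values):
+ >
+ > ```python
+ > import os
+ > os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "HuggingFace"  # illustrative value, not necessarily the exact accepted string
+ > ```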
227
+
228
+ </details>
229
+
230
+ ### Image Synthesis
231
+
232
+ ![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
233
+
234
+ #### Z-Image: [/docs/en/Model_Details/Z-Image.md](/docs/en/Model_Details/Z-Image.md)
235
+
236
+ <details>
237
+
238
+ <summary>Quick Start</summary>
239
+
240
+ Running the following code will quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model for inference. FP8 quantization significantly degrades image quality, so we do not recommend enabling any quantization for the Z-Image Turbo model. CPU offloading is recommended, and the model can run with as little as 8 GB of GPU memory.
241
+
242
+ ```python
243
+ from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
244
+ import torch
245
+
246
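+ # Our reading of the VRAM management settings below:
+ # "offload_*"  - dtype/device where weights rest when idle,
+ # "onload_*"   - dtype/device where weights are staged before use,
+ # "preparing_*" / "computation_*" - dtype/device used to prepare and run the forward pass.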
+ vram_config = {
247
+ "offload_dtype": torch.bfloat16,
248
+ "offload_device": "cpu",
249
+ "onload_dtype": torch.bfloat16,
250
+ "onload_device": "cpu",
251
+ "preparing_dtype": torch.bfloat16,
252
+ "preparing_device": "cuda",
253
+ "computation_dtype": torch.bfloat16,
254
+ "computation_device": "cuda",
255
+ }
256
+ pipe = ZImagePipeline.from_pretrained(
257
+ torch_dtype=torch.bfloat16,
258
+ device="cuda",
259
+ model_configs=[
260
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
261
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
262
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
263
+ ],
264
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
265
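+ # vram_limit: total GPU memory in GiB minus a 0.5 GiB safety margin
+ # (torch.cuda.mem_get_info returns (free_bytes, total_bytes)).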
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
266
+ )
267
+ prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
268
+ image = pipe(prompt=prompt, seed=42, rand_device="cuda")
269
+ image.save("image.jpg")
270
+ ```
271
+
272
+ </details>
273
+
274
+ <details>
275
+
276
+ <summary>Examples</summary>
277
+
278
+ Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
279
+
280
+ | Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
281
+ |-|-|-|-|-|-|-|
282
+ |[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
283
+
284
+ </details>
285
+
286
+ #### FLUX.2: [/docs/en/Model_Details/FLUX2.md](/docs/en/Model_Details/FLUX2.md)
287
+
288
+ <details>
289
+
290
+ <summary>Quick Start</summary>
291
+
292
+ Running the following code will quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model for inference. VRAM management is enabled, and the framework automatically loads model parameters based on available GPU memory. The model can run with as little as 10 GB of VRAM.
293
+
294
+ ```python
295
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
296
+ import torch
297
+
298
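+ # Our reading of the settings below: "disk" offloading keeps idle weights on disk and streams
+ # them in when needed, while FP8 (float8_e4m3fn) storage with bfloat16 computation keeps VRAM usage low.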
+ vram_config = {
299
+ "offload_dtype": "disk",
300
+ "offload_device": "disk",
301
+ "onload_dtype": torch.float8_e4m3fn,
302
+ "onload_device": "cpu",
303
+ "preparing_dtype": torch.float8_e4m3fn,
304
+ "preparing_device": "cuda",
305
+ "computation_dtype": torch.bfloat16,
306
+ "computation_device": "cuda",
307
+ }
308
+ pipe = Flux2ImagePipeline.from_pretrained(
309
+ torch_dtype=torch.bfloat16,
310
+ device="cuda",
311
+ model_configs=[
312
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
313
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
314
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
315
+ ],
316
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
317
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
318
+ )
319
+ prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
320
+ image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
321
+ image.save("image.jpg")
322
+ ```
323
+
324
+ </details>
325
+
326
+ <details>
327
+
328
+ <summary>Examples</summary>
329
+
330
+ Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
331
+
332
+ | Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
333
+ |-|-|-|-|-|
334
+ |[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
335
+
336
+ </details>
337
+
338
+ #### Qwen-Image: [/docs/en/Model_Details/Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
339
+
340
+ <details>
341
+
342
+ <summary>Quick Start</summary>
343
+
344
+ Running the following code will quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
345
+
346
+ ```python
347
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
348
+ import torch
349
+
350
+ vram_config = {
351
+ "offload_dtype": "disk",
352
+ "offload_device": "disk",
353
+ "onload_dtype": torch.float8_e4m3fn,
354
+ "onload_device": "cpu",
355
+ "preparing_dtype": torch.float8_e4m3fn,
356
+ "preparing_device": "cuda",
357
+ "computation_dtype": torch.bfloat16,
358
+ "computation_device": "cuda",
359
+ }
360
+ pipe = QwenImagePipeline.from_pretrained(
361
+ torch_dtype=torch.bfloat16,
362
+ device="cuda",
363
+ model_configs=[
364
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
365
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
366
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
367
+ ],
368
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
369
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
370
+ )
371
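+ # The prompt below is in Chinese (Qwen-Image supports Chinese prompts); roughly:
+ # "Exquisite portrait, underwater girl, flowing blue dress, drifting hair, clear light and shadow,
+ # surrounded by bubbles, serene face, fine details, dreamy and beautiful."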
+ prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
372
+ image = pipe(prompt, seed=0, num_inference_steps=40)
373
+ image.save("image.jpg")
374
+ ```
375
+
376
+ </details>
377
+
378
+ <details>
379
+
380
+ <summary>Model Lineage</summary>
381
+
382
+ ```mermaid
383
+ graph LR;
384
+ Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
385
+ Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
386
+ Qwen/Qwen-Image-->EliGen-Series;
387
+ EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
388
+ DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
389
+ EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
390
+ Qwen/Qwen-Image-->Distill-Series;
391
+ Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
392
+ Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
393
+ Qwen/Qwen-Image-->ControlNet-Series;
394
+ ControlNet-Series-->Blockwise-ControlNet-Series;
395
+ Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
396
+ Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
397
+ Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
398
+ ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
399
+ Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
400
+ ```
401
+
402
+ </details>
403
+
404
+ <details>
405
+
406
+ <summary>Examples</summary>
407
+
408
+ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/qwen_image/)
409
+
410
+ | Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
411
+ |-|-|-|-|-|-|-|
412
+ |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
413
+ |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
414
+ |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
415
+ |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
416
+ |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
417
+ |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
418
+ |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
419
+ |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
420
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
421
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
422
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
423
+ |[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
424
+ |[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
425
+ |[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
426
+
427
+ </details>
428
+
429
+ #### FLUX.1: [/docs/en/Model_Details/FLUX.md](/docs/en/Model_Details/FLUX.md)
430
+
431
+ <details>
432
+
433
+ <summary>Quick Start</summary>
434
+
435
+ Running the following code will quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
436
+
437
+ ```python
438
+ import torch
439
+ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
440
+
441
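+ # Our reading of the settings below: weights are stored in FP8 (float8_e4m3fn) on the CPU
+ # and computed in bfloat16 on the GPU.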
+ vram_config = {
442
+ "offload_dtype": torch.float8_e4m3fn,
443
+ "offload_device": "cpu",
444
+ "onload_dtype": torch.float8_e4m3fn,
445
+ "onload_device": "cpu",
446
+ "preparing_dtype": torch.float8_e4m3fn,
447
+ "preparing_device": "cuda",
448
+ "computation_dtype": torch.bfloat16,
449
+ "computation_device": "cuda",
450
+ }
451
+ pipe = FluxImagePipeline.from_pretrained(
452
+ torch_dtype=torch.bfloat16,
453
+ device="cuda",
454
+ model_configs=[
455
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
456
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
457
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
458
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
459
+ ],
460
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
461
+ )
462
+ prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
463
+ image = pipe(prompt=prompt, seed=0)
464
+ image.save("image.jpg")
465
+ ```
466
+
467
+ </details>
468
+
469
+ <details>
470
+
471
+ <summary>Model Lineage</summary>
472
+
473
+ ```mermaid
474
+ graph LR;
475
+ FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
476
+ FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
477
+ FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
478
+ black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
479
+ FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
480
+ FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
481
+ FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
482
+ black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
483
+ black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
484
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
485
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
486
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
487
+ black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
488
+ black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
489
+ Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
490
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
491
+ Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
492
+ ```
493
+
494
+ </details>
495
+
496
+ <details>
497
+
498
+ <summary>Examples</summary>
499
+
500
+ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
501
+
502
+ | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
503
+ |-|-|-|-|-|-|-|-|
504
+ |[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
505
+ |[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
506
+ |[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
507
+ |[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
508
+ |[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
509
+ |[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
510
+ |[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
511
+ |[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
512
+ |[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
513
+ |[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
514
+ |[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
515
+ |[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
516
+ |[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
517
+ |[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
518
+
519
+ </details>
520
+
521
+ ### Video Synthesis
522
+
523
+ https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
524
+
525
+ #### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
526
+
527
+ <details>
528
+
529
+ <summary>Quick Start</summary>
530
+
531
+ Running the following code will quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8 GB of VRAM.
532
+
533
+ ```python
534
+ import torch
535
+ from diffsynth.utils.data import save_video, VideoData
536
+ from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
537
+
538
+ vram_config = {
539
+ "offload_dtype": "disk",
540
+ "offload_device": "disk",
541
+ "onload_dtype": torch.bfloat16,
542
+ "onload_device": "cpu",
543
+ "preparing_dtype": torch.bfloat16,
544
+ "preparing_device": "cuda",
545
+ "computation_dtype": torch.bfloat16,
546
+ "computation_device": "cuda",
547
+ }
548
+ pipe = WanVideoPipeline.from_pretrained(
549
+ torch_dtype=torch.bfloat16,
550
+ device="cuda",
551
+ model_configs=[
552
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
553
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
554
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
555
+ ],
556
+ tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
557
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
558
+ )
559
+
560
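+ # The Chinese prompt describes a documentary-style shot of a playful brown-yellow puppy running
+ # across a sunny lawn dotted with wildflowers, blue sky in the distance, captured from a moving
+ # side view; the negative prompt lists common artifacts to avoid (oversaturation, blur, extra limbs, etc.).
+ # tiled=True enables tiled VAE processing, which, as we understand it, lowers peak VRAM usage.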
+ video = pipe(
561
+ prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
562
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
563
+ seed=0, tiled=True,
564
+ )
565
+ save_video(video, "video.mp4", fps=15, quality=5)
566
+ ```
567
+
568
+ </details>
569
+
570
+ <details>
571
+
572
+ <summary>Model Lineage</summary>
573
+
574
+ ```mermaid
575
+ graph LR;
576
+ Wan-Series-->Wan2.1-Series;
577
+ Wan-Series-->Wan2.2-Series;
578
+ Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
579
+ Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
580
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
581
+ Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
582
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
583
+ Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
584
+ iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
585
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
586
+ Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
587
+ Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
588
+ Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
589
+ Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
590
+ Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
591
+ Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
592
+ Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
593
+ Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
594
+ Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
595
+ Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
596
+ Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
597
+ Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
598
+ Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
599
+ Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
600
+ Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
601
+ Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
602
+ Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
603
+ Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
604
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
605
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
606
+ Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
607
+ Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
608
+ Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
609
+ Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
610
+ Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
611
+ Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
612
+ Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
613
+ Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
614
+ ```
615
+
616
+ </details>
617
+
618
+ <details>
619
+
620
+ <summary>Examples</summary>
621
+
622
+ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
623
+
624
+ | Model ID | Extra Args | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
625
+ |-|-|-|-|-|-|-|
626
+ |[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
627
+ |[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
628
+ |[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
629
+ |[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
630
+ |[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
631
+ |[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
632
+ |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
633
+ |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
634
+ |[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
635
+ |[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
636
+ |[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
637
+ |[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
638
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
639
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
640
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
641
+ |[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
642
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
643
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
644
+ |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
645
+ |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
646
+ |[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
647
+ |[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
648
+ |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
649
+ |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
650
+ |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
651
+ |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
652
+ |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
653
+ |[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
654
+ |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
655
+ |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
656
+ |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
657
+
658
+ </details>
659
+
660
+ ## Innovative Achievements
661
+
662
+ DiffSynth-Studio is not just an engineering-oriented model framework; it also serves as an incubator for the innovative achievements below.
663
+
664
+ <details>
665
+
666
+ <summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
667
+
668
+ - Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
669
+ - Sample Code: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
670
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
671
+
672
+ |brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
673
+ |-|-|-|-|-|
674
+ |![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
675
+
676
+ </details>
677
+
678
+
679
+ <details>
680
+
681
+ <summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>
682
+
683
+ - Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
684
+ - Sample Code: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
685
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
686
+
687
+ ||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
688
+ |-|-|-|-|-|
689
+ |[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
690
+ |[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
691
+ |[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
692
+ |[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
693
+
694
+ </details>
695
+
696
+
697
+ <details>
698
+
699
+ <summary>Nexus-Gen: Unified Architecture for Image Understanding, Generation, and Editing</summary>
700
+
701
+ - Detailed Page: https://github.com/modelscope/Nexus-Gen
702
+ - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
703
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
704
+ - Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
705
+ - Online Experience: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
706
+
707
+ ![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
708
+
709
+ </details>
710
+
711
+
712
+ <details>
713
+
714
+ <summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
715
+
716
+ - Detailed Page: [./examples/ArtAug/](./examples/ArtAug/)
717
+ - Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
718
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
719
+ - Online Experience: [ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
720
+
721
+ |FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
722
+ |-|-|
723
+ |![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
724
+
725
+ </details>
726
+
727
+
728
+ <details>
729
+
730
+ <summary>EliGen: Precise Image Partition Control</summary>
731
+
732
+ - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
733
+ - Sample Code: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
734
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
735
+ - Online Experience: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
736
+ - Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
737
+
738
+ |Entity Control Region|Generated Image|
739
+ |-|-|
740
+ |![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
741
+
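+ The entity-control interface is exposed through the FLUX image pipeline. Below is a minimal sketch, assuming the EliGen weights are loaded as in the linked sample script (omitted here); the `eligen_entity_prompts` / `eligen_entity_masks` argument names follow the FLUX.1 example table in this repository, and the prompt and mask file names are placeholders.
+
+ ```python
+ import torch
+ from PIL import Image
+ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
+
+ # Load FLUX.1-dev; loading the EliGen weights themselves follows
+ # /examples/flux/model_inference/FLUX.1-dev-EliGen.py and is omitted here.
+ pipe = FluxImagePipeline.from_pretrained(
+     torch_dtype=torch.bfloat16,
+     device="cuda",
+     model_configs=[
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
+         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
+     ],
+ )
+ # One binary mask per entity (white marks the region the entity should occupy);
+ # the mask file names below are placeholders.
+ entity_masks = [Image.open("mask_girl.png"), Image.open("mask_lantern.png")]
+ image = pipe(
+     prompt="a girl holding a glowing paper lantern in a night garden",
+     eligen_entity_prompts=["a girl in a blue dress", "a glowing paper lantern"],
+     eligen_entity_masks=entity_masks,
+     seed=0,
+ )
+ image.save("eligen_sketch.jpg")
+ ```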
742
+ </details>
743
+
744
+
745
+ <details>
746
+
747
+ <summary>ExVideo: Extended Training for Video Generation Models</summary>
748
+
749
+ - Project Page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
750
+ - Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
751
+ - Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
752
+ - Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
753
+
754
+ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
755
+
756
+ </details>
757
+
758
+
759
+ <details>
760
+
761
+ <summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
762
+
763
+ - Project Page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
764
+ - Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
765
+ - Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)
766
+
767
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
768
+
769
+ </details>
770
+
771
+
772
+ <details>
773
+
774
+ <summary>DiffSynth: The Original Version of This Project</summary>
775
+
776
+ - Project Page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
777
+ - Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
778
+ - Sample Code: Please refer to the [older version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
779
+
780
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
781
+
782
+ </details>
783
+
README_zh.md ADDED
@@ -0,0 +1,784 @@
1
+ # DiffSynth-Studio
2
+
3
+ <a href="https://github.com/modelscope/DiffSynth-Studio"><img src=".github/workflows/logo.gif" title="Logo" style="max-width:100%;" width="55" /></a> <a href="https://trendshift.io/repositories/10946" target="_blank"><img src="https://trendshift.io/api/badge/repositories/10946" alt="modelscope%2FDiffSynth-Studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a></p>
4
+
5
+ [![PyPI](https://img.shields.io/pypi/v/DiffSynth)](https://pypi.org/project/DiffSynth/)
6
+ [![license](https://img.shields.io/github/license/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/blob/master/LICENSE)
7
+ [![open issues](https://isitmaintained.com/badge/open/modelscope/DiffSynth-Studio.svg)](https://github.com/modelscope/DiffSynth-Studio/issues)
8
+ [![GitHub pull-requests](https://img.shields.io/github/issues-pr/modelscope/DiffSynth-Studio.svg)](https://GitHub.com/modelscope/DiffSynth-Studio/pull/)
9
+ [![GitHub latest commit](https://badgen.net/github/last-commit/modelscope/DiffSynth-Studio)](https://GitHub.com/modelscope/DiffSynth-Studio/commit/)
10
+
11
+ [Switch to English](./README.md)
12
+
13
+ ## 简介
14
+
15
+ 欢迎来到 Diffusion 模型的魔法世界!DiffSynth-Studio 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
16
+
17
+ DiffSynth 目前包括两个开源项目:
18
+ * [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio): 聚焦于激进的技术探索,面向学术界,提供更前沿的模型能力支持。
19
+ * [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine): 聚焦于稳定的模型部署,面向工业界,提供更高的计算性能与更稳定的功能。
20
+
21
+ [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) 与 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 是魔搭社区 AIGC 专区的核心引擎,欢迎体验我们精心打造的产品化功能:
22
+
23
+ * 魔搭社区 AIGC 专区 (面向中国用户): https://modelscope.cn/aigc/home
24
+ * ModelScope Civision (for global users): https://modelscope.ai/civision/home
25
+
26
+ > DiffSynth-Studio 文档:[中文版](/docs/zh/README.md)、[English version](/docs/en/README.md)
27
+
28
+ 我们相信,一个完善的开源代码框架能够降低技术探索的门槛,我们基于这个代码库搞出了不少[有意思的技术](#创新成果)。或许你也有许多天马行空的构想,借助 DiffSynth-Studio,你可以快速实现这些想法。为此,我们为开发者准备了详细的文档,我们希望通过这些文档,帮助开发者理解 Diffusion 模型的原理,更期待与你一同拓展技术的边界。
29
+
30
+ ## 更新历史
31
+
32
+ > DiffSynth-Studio 经历了大版本更新,部分旧功能已停止维护,如需使用旧版功能,请切换到大版本更新前的[最后一个历史版本](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3)。
33
+
34
+ > 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
35
+
36
+ - **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。
37
+
38
+ - **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
39
+ - [文档](/docs/zh/README.md)上线:我们的文档还在持续优化更新中
40
+ - [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)模块升级,支持 Layer 级别的 Disk Offload,同时释放内存与显存
41
+ - 新模型支持
42
+ - Z-Image Turbo: [模型](https://www.modelscope.ai/models/Tongyi-MAI/Z-Image-Turbo)、[文档](/docs/zh/Model_Details/Z-Image.md)、[代码](/examples/z_image/)
43
+ - FLUX.2-dev: [模型](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)、[文档](/docs/zh/Model_Details/FLUX2.md)、[代码](/examples/flux2/)
44
+ - 训练框架升级
45
+ - [拆分训练](/docs/zh/Training/Split_Training.md):支持自动化地将训练过程拆分为数据处理和训练两阶段(即使训练的是 ControlNet 或其他任意模型),在数据处理阶段进行文本编码、VAE 编码等不需要梯度回传的计算,在训练阶段处理其他计算。速度更快,显存需求更少。
46
+ - [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md):这是我们曾在 [ArtAug](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) 中使用的训练技术,目前已可用于任意模型的 LoRA 训练。
47
+ - [FP8 训练](/docs/zh/Training/FP8_Precision.md):FP8 在训练中支持应用到任意非训练模型,即梯度关闭或者梯度仅影响 LoRA 权重的模型。
48
+
49
+ <details>
50
+ <summary>更多</summary>
51
+
52
+ - **2025年11月4日** 支持了 [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) 模型,该模型基于 Wan 2.1 训练,支持根据参考视频生成相应的动作。
53
+
54
+ - **2025年10月30日** 支持了 [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) 模型,该模型支持文生视频、图生视频、视频续写。这个模型在本项目中沿用 Wan 的框架进行推理和训练。
55
+
56
+ - **2025年10月27日** 支持了 [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) 模型,Wan 模型生态再添一员。
57
+
58
+ - **2025年9月23日** [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) 发布!本模型由我们与淘天体验设计团队联合研发并开源。模型基于 Qwen-Image 构建,专为电商海报场景设计,支持精确的分区布局控制。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)。
59
+
60
+ - **2025年9月9日** 我们的训练框架支持了多种训练模式,目前已适配 Qwen-Image,除标准 SFT 训练模式外,已支持 Direct Distill,请参考[我们的示例代码](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)。这项功能是实验性的,我们将会继续完善以支持更全面的模型训练功能。
61
+
62
+ - **2025年8月28日** 我们支持了Wan2.2-S2V,一个音频驱动的电影级视频生成模型。请参见[./examples/wanvideo/](./examples/wanvideo/)。
63
+
64
+ - **2025年8月21日** [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) 发布!相比于 V1 版本,训练数据集变为 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset),因此,生成的图像更符合 Qwen-Image 本身的图像分布和风格。 请参考[我们的示例代码](./examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)。
65
+
66
+ - **2025年8月21日** 我们开源了 [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) 结构控制 LoRA 模型,采用 In Context 的技术路线,支持多种类别的结构控制条件,包括 canny, depth, lineart, softedge, normal, openpose。 请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)。
67
+
68
+ - **2025年8月20日** 我们开源了 [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) 模型,提升了 Qwen-Image-Edit 对低分辨率图像输入的编辑效果。请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)
69
+
70
+ - **2025年8月19日** 🔥 Qwen-Image-Edit 开源,欢迎图像编辑模型新成员!
71
+
72
+ - **2025年8月18日** 我们训练并开源了 Qwen-Image 的图像重绘 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)。
73
+
74
+ - **2025年8月15日** 我们开源了 [Qwen-Image-Self-Generated-Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Qwen-Image-Self-Generated-Dataset) 数据集。这是一个使用 Qwen-Image 模型生成的图像数据集,共包含 160,000 张`1024 x 1024`图像。它包括通用、英文文本渲染和中文文本渲染子集。我们为每张图像提供了图像描述、实体和结构控制图像的标注。开发者可以使用这个数据集来训练 Qwen-Image 模型的 ControlNet 和 EliGen 等模型,我们旨在通过开源推动技术发展!
75
+
76
+ - **2025年8月13日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)。
77
+
78
+ - **2025年8月12日** 我们训练并开源了 Qwen-Image 的 ControlNet 模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny),模型结构采用了轻量化的设计,请参考[我们的示例代码](./examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)。
79
+
80
+ - **2025年8月11日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA),沿用了与 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) 相同的训练流程,但模型结构修改为了 LoRA,因此能够更好地与其他开源生态模型兼容。
81
+
82
+ - **2025年8月7日** 我们开源了 Qwen-Image 的实体控制 LoRA 模型 [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)。Qwen-Image-EliGen 能够实现实体级可控的文生图。技术细节请参见[论文](https://arxiv.org/abs/2501.01097)。训练数据集:[EliGenTrainSet](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)。
83
+
84
+ - **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
85
+
86
+ - **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
87
+
88
+ - **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
89
+
90
+ - **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
91
+
92
+ - **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
93
+ - 论文: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
94
+ - Github 仓库: https://github.com/modelscope/Nexus-Gen
95
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
96
+ - 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
97
+ - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
98
+
99
+ - **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
100
+
101
+ - **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
102
+
103
+ - **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
104
+
105
+ - **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
106
+
107
+ - **2025年2月25日** 我们支持 Wan-Video,这是阿里巴巴开源的一系列最先进的视频合成模型。详见 [./examples/wanvideo/](./examples/wanvideo/)。
108
+
109
+ - **2025年2月17日** 我们支持 [StepVideo](https://modelscope.cn/models/stepfun-ai/stepvideo-t2v/summary)!先进的视频合成模型!详见 [./examples/stepvideo](./examples/stepvideo/)。
110
+
111
+ - **2024年12月31日** 我们提出 EliGen,一种用于精确实体级别控制的文本到图像生成的新框架,并辅以修复融合管道,将其能力扩展到图像修复任务。EliGen 可以无缝集成现有的社区模型,如 IP-Adapter 和 In-Context LoRA,提升其通用性。更多详情,请见 [./examples/EntityControl](./examples/EntityControl/)。
112
+ - 论文: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
113
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
114
+ - 在线体验: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
115
+ - 训练数据集: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
116
+
117
+ - **2024年12月19日** 我们为 HunyuanVideo 实现了高级显存管理,使得在 24GB 显存下可以生成分辨率为 129x720x1280 的视频,或在仅 6GB 显存下生成分辨率为 129x512x384 的视频。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。
118
+
119
+ - **2024年12月18日** 我们提出 ArtAug,一种通过合成-理解交互来改进文生图模型的方法。我们以 LoRA 格式为 FLUX.1-dev 训练了一个 ArtAug 增强模块。该模型将 Qwen2-VL-72B 的美学理解融入 FLUX.1-dev,从而提升了生成图像的质量。
120
+ - 论文: https://arxiv.org/abs/2412.12888
121
+ - 示例: https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/ArtAug
122
+ - 模型: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
123
+ - 演示: [ModelScope](https://modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0), HuggingFace (即将上线)
124
+
125
+ - **2024年10月25日** 我们提供了广泛的 FLUX ControlNet 支持。该项目支持许多不同的 ControlNet 模型,并且可以自由组合,即使它们的结构不同。此外,ControlNet 模型兼容高分辨率优化和分区控制技术,能够实现非常强大的可控图像生成。详见 [`./examples/ControlNet/`](./examples/ControlNet/)。
126
+
127
+ - **2024年10月8日** 我们发布了基于 CogVideoX-5B 和 ExVideo 的扩展 LoRA。您可以从 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 或 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-CogVideoX-LoRA-129f-v1) 下载此模型。
128
+
129
+ - **2024年8月22日** 本项目现已支持 CogVideoX-5B。详见 [此处](/examples/video_synthesis/)。我们为这个文生视频模型提供了几个有趣的功能,包括:
130
+ - 文本到视频
131
+ - 视频编辑
132
+ - 自我超分
133
+ - 视频插帧
134
+
135
+ - **2024年8月22日** 我们实现了一个有趣的画笔功能,支持所有文生图模型。现在,您可以在 AI 的辅助下使用画笔创作惊艳的图像了!
136
+ - 在我们的 [WebUI](#usage-in-webui) 中使用它。
137
+
138
+ - **2024年8月21日** DiffSynth-Studio 现已支持 FLUX。
139
+ - 启用 CFG 和高分辨率修复以提升视觉质量。详见 [此处](/examples/image_synthesis/README.md)
140
+ - LoRA、ControlNet 和其他附加模型将很快推出。
141
+
142
+ - **2024年6月21日** 我们提出 ExVideo,一种旨在增强视频生成模型能力的后训练微调技术。我们将 Stable Video Diffusion 进行了扩展,实现了长达 128 帧的长视频生成。
143
+ - [项目页面](https://ecnu-cilab.github.io/ExVideoProjectPage/)
144
+ - 源代码已在此仓库中发布。详见 [`examples/ExVideo`](./examples/ExVideo/)。
145
+ - 模型已发布于 [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) 和 [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1)。
146
+ - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2406.14130)。
147
+ - 您可以在此 [演示](https://huggingface.co/spaces/modelscope/ExVideo-SVD-128f-v1) 中试用 ExVideo!
148
+
149
+ - **2024年6月13日** DiffSynth Studio 已迁移至 ModelScope。开发团队也从“我”转变为“我们”。当然,我仍会参与后续的开发和维护工作。
150
+
151
+ - **2024年1月29日** 我们提出 Diffutoon,这是一个出色的卡通着色解决方案。
152
+ - [项目页面](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
153
+ - 源代码已在此项目中发布。
154
+ - 技术报告(IJCAI 2024)已发布于 [arXiv](https://arxiv.org/abs/2401.16224)。
155
+
156
+ - **2023年12月8日** 我们决定启动一个新项目,旨在释放扩散模型的潜力,尤其是在视频合成方面。该项目的开发工作正式开始。
157
+
158
+ - **2023年11月15日** 我们提出 FastBlend,一种强大的视频去闪烁算法。
159
+ - sd-webui 扩展已发布于 [GitHub](https://github.com/Artiprocher/sd-webui-fastblend)。
160
+ - 演示视频已在 Bilibili 上展示,包含三个任务:
161
+ - [视频去闪烁](https://www.bilibili.com/video/BV1d94y1W7PE)
162
+ - [视频插帧](https://www.bilibili.com/video/BV1Lw411m71p)
163
+ - [图像驱动的视频渲染](https://www.bilibili.com/video/BV1RB4y1Z7LF)
164
+ - 技术报告已发布于 [arXiv](https://arxiv.org/abs/2311.09265)。
165
+ - 其他用户开发的非官方 ComfyUI 扩展已发布于 [GitHub](https://github.com/AInseven/ComfyUI-fastblend)。
166
+
167
+ - **2023年10月1日** 我们发布了该项目的早期版本,名为 FastSDXL。这是构建一个扩散引擎的初步尝试。
168
+ - 源代码已发布于 [GitHub](https://github.com/Artiprocher/FastSDXL)。
169
+ - FastSDXL 包含一个可训练的 OLSS 调度器,以提高效率。
170
+ - OLSS 的原始仓库位于 [此处](https://github.com/alibaba/EasyNLP/tree/master/diffusion/olss_scheduler)。
171
+ - 技术报告(CIKM 2023)已发布于 [arXiv](https://arxiv.org/abs/2305.14677)。
172
+ - 演示视频已发布于 [Bilibili](https://www.bilibili.com/video/BV1w8411y7uj)。
173
+ - 由于 OLSS 需要额外训练,我们未在本项目中实现它。
174
+
175
+ - **2023年8月29日** 我们提出 DiffSynth,一个视频合成框架。
176
+ - [项目页面](https://ecnu-cilab.github.io/DiffSynth.github.io/)。
177
+ - 源代码已发布在 [EasyNLP](https://github.com/alibaba/EasyNLP/tree/master/diffusion/DiffSynth)。
178
+ - 技术报告(ECML PKDD 2024)已发布于 [arXiv](https://arxiv.org/abs/2308.03463)。
179
+
180
+ </details>
181
+
182
+ ## 安装
183
+
184
+ 从源码安装(推荐):
185
+
186
+ ```
187
+ git clone https://github.com/modelscope/DiffSynth-Studio.git
188
+ cd DiffSynth-Studio
189
+ pip install -e .
190
+ ```
191
+
192
+ <details>
193
+ <summary>其他安装方式</summary>
194
+
195
+ 从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
196
+
197
+ ```
198
+ pip install diffsynth
199
+ ```
200
+
201
+ 如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
202
+
203
+ * [torch](https://pytorch.org/get-started/locally/)
204
+ * [sentencepiece](https://github.com/google/sentencepiece)
205
+ * [cmake](https://cmake.org)
206
+ * [cupy](https://docs.cupy.dev/en/stable/install.html)
207
+
208
+ </details>
209
+
210
+ ## 基础框架
211
+
212
+ DiffSynth-Studio 为主流 Diffusion 模型(包括 FLUX、Wan 等)重新设计了推理和训练流水线,能够实现高效的显存管理、灵活的模型训练。
213
+
214
+ <details>
215
+ <summary>环境变量配置</summary>
216
+
217
+ > 在进行模型推理和训练前,可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)配置模型下载源等。
218
+ >
219
+ > 本项目默认从魔搭社区下载模型。对于非中国区域的用户,可以通过以下配置从魔搭社区的国际站下载模型:
220
+ >
221
+ > ```python
222
+ > import os
223
+ > os.environ["MODELSCOPE_DOMAIN"] = "www.modelscope.ai"
224
+ > ```
225
+ >
226
+ > 如需从其他站点下载,请修改[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source)。
227
+
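+ > 下面给出一个极简的假设性示例,仅示意设置方式:其中 `"modelscope"` 这一取值为占位,实际可用取值请以上方链接的环境变量文档为准。
+ >
+ > ```python
+ > import os
+ > # 假设性取值,仅作占位;可用取值见环境变量文档
+ > os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "modelscope"
+ > ```
+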
228
+ </details>
229
+
230
+ ### 图像生成模型
231
+
232
+ ![Image](https://github.com/user-attachments/assets/c01258e2-f251-441a-aa1e-ebb22f02594d)
233
+
234
+ #### Z-Image:[/docs/zh/Model_Details/Z-Image.md](/docs/zh/Model_Details/Z-Image.md)
235
+
236
+ <details>
237
+
238
+ <summary>快速开始</summary>
239
+
240
+ 运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload,最低 8G 显存即可运行。
241
+
242
+ ```python
243
+ from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
244
+ import torch
245
+
246
+ vram_config = {
247
+ "offload_dtype": torch.bfloat16,
248
+ "offload_device": "cpu",
249
+ "onload_dtype": torch.bfloat16,
250
+ "onload_device": "cpu",
251
+ "preparing_dtype": torch.bfloat16,
252
+ "preparing_device": "cuda",
253
+ "computation_dtype": torch.bfloat16,
254
+ "computation_device": "cuda",
255
+ }
256
+ pipe = ZImagePipeline.from_pretrained(
257
+ torch_dtype=torch.bfloat16,
258
+ device="cuda",
259
+ model_configs=[
260
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
261
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
262
+ ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
263
+ ],
264
+ tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
265
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
266
+ )
267
+ prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
268
+ image = pipe(prompt=prompt, seed=42, rand_device="cuda")
269
+ image.save("image.jpg")
270
+ ```
271
+
272
+ </details>
273
+
274
+ <details>
275
+
276
+ <summary>示例代码</summary>
277
+
278
+ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
279
+
280
+ |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
281
+ |-|-|-|-|-|-|-|
282
+ |[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
283
+
284
+ </details>
285
+
286
+ #### FLUX.2: [/docs/zh/Model_Details/FLUX2.md](/docs/zh/Model_Details/FLUX2.md)
287
+
288
+ <details>
289
+
290
+ <summary>快速开始</summary>
291
+
292
+ 运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。
293
+
294
+ ```python
295
+ from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
296
+ import torch
297
+
298
+ vram_config = {
299
+ "offload_dtype": "disk",
300
+ "offload_device": "disk",
301
+ "onload_dtype": torch.float8_e4m3fn,
302
+ "onload_device": "cpu",
303
+ "preparing_dtype": torch.float8_e4m3fn,
304
+ "preparing_device": "cuda",
305
+ "computation_dtype": torch.bfloat16,
306
+ "computation_device": "cuda",
307
+ }
308
+ pipe = Flux2ImagePipeline.from_pretrained(
309
+ torch_dtype=torch.bfloat16,
310
+ device="cuda",
311
+ model_configs=[
312
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
313
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
314
+ ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
315
+ ],
316
+ tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
317
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
318
+ )
319
+ prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
320
+ image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
321
+ image.save("image.jpg")
322
+ ```
323
+
324
+ </details>
325
+
326
+ <details>
327
+
328
+ <summary>示例代码</summary>
329
+
330
+ FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
331
+
332
+ |模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
333
+ |-|-|-|-|-|
334
+ |[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
335
+
336
+ </details>
337
+
338
+ #### Qwen-Image: [/docs/zh/Model_Details/Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
339
+
340
+ <details>
341
+
342
+ <summary>快速开始</summary>
343
+
344
+ 运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
345
+
346
+ ```python
347
+ from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
348
+ import torch
349
+
350
+ vram_config = {
351
+ "offload_dtype": "disk",
352
+ "offload_device": "disk",
353
+ "onload_dtype": torch.float8_e4m3fn,
354
+ "onload_device": "cpu",
355
+ "preparing_dtype": torch.float8_e4m3fn,
356
+ "preparing_device": "cuda",
357
+ "computation_dtype": torch.bfloat16,
358
+ "computation_device": "cuda",
359
+ }
360
+ pipe = QwenImagePipeline.from_pretrained(
361
+ torch_dtype=torch.bfloat16,
362
+ device="cuda",
363
+ model_configs=[
364
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
365
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
366
+ ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
367
+ ],
368
+ tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
369
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
370
+ )
371
+ prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
372
+ image = pipe(prompt, seed=0, num_inference_steps=40)
373
+ image.save("image.jpg")
374
+ ```
375
+
376
+ </details>
377
+
378
+ <details>
379
+
380
+ <summary>模型血缘</summary>
381
+
382
+ ```mermaid
383
+ graph LR;
384
+ Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
385
+ Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
386
+ Qwen/Qwen-Image-->EliGen-Series;
387
+ EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
388
+ DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
389
+ EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
390
+ Qwen/Qwen-Image-->Distill-Series;
391
+ Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
392
+ Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
393
+ Qwen/Qwen-Image-->ControlNet-Series;
394
+ ControlNet-Series-->Blockwise-ControlNet-Series;
395
+ Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
396
+ Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
397
+ Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
398
+ ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
399
+ Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
400
+ ```
401
+
402
+ </details>
403
+
404
+ <details>
405
+
406
+ <summary>示例代码</summary>
407
+
408
+ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/)
409
+
410
+ |模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
411
+ |-|-|-|-|-|-|-|
412
+ |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
413
+ |[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
414
+ |[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
415
+ |[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
416
+ |[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
417
+ |[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
418
+ |[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
419
+ |[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
420
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
421
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
422
+ |[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
423
+ |[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
424
+ |[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
425
+ |[DiffSynth-Studio/Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)|[code](/examples/qwen_image/model_inference/Qwen-Image-i2L.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-i2L.py)|-|-|-|-|
426
+
427
+ </details>
428
+
429
+ #### FLUX.1: [/docs/zh/Model_Details/FLUX.md](/docs/zh/Model_Details/FLUX.md)
430
+
431
+ <details>
432
+
433
+ <summary>快速开始</summary>
434
+
435
+ 运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
436
+
437
+ ```python
438
+ import torch
439
+ from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
440
+
441
+ vram_config = {
442
+ "offload_dtype": torch.float8_e4m3fn,
443
+ "offload_device": "cpu",
444
+ "onload_dtype": torch.float8_e4m3fn,
445
+ "onload_device": "cpu",
446
+ "preparing_dtype": torch.float8_e4m3fn,
447
+ "preparing_device": "cuda",
448
+ "computation_dtype": torch.bfloat16,
449
+ "computation_device": "cuda",
450
+ }
451
+ pipe = FluxImagePipeline.from_pretrained(
452
+ torch_dtype=torch.bfloat16,
453
+ device="cuda",
454
+ model_configs=[
455
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
456
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
457
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
458
+ ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
459
+ ],
460
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
461
+ )
462
+ prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
463
+ image = pipe(prompt=prompt, seed=0)
464
+ image.save("image.jpg")
465
+ ```
466
+
467
+ </details>
468
+
469
+ <details>
470
+
471
+ <summary>模型血缘</summary>
472
+
473
+ ```mermaid
474
+ graph LR;
475
+ FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
476
+ FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
477
+ FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
478
+ black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
479
+ FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
480
+ FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
481
+ FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
482
+ black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
483
+ black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
484
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
485
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
486
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
487
+ black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
488
+ black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
489
+ Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
490
+ black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
491
+ Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
492
+ ```
493
+
494
+ </details>
495
+
496
+ <details>
497
+
498
+ <summary>示例代码</summary>
499
+
500
+ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
501
+
502
+ |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
503
+ |-|-|-|-|-|-|-|-|
504
+ |[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
505
+ |[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
506
+ |[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
507
+ |[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
508
+ |[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
509
+ |[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
510
+ |[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
511
+ |[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
512
+ |[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
513
+ |[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
514
+ |[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
515
+ |[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
516
+ |[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
517
+ |[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
518
+
519
+ </details>
520
+
521
+ ### 视频生成模型
522
+
523
+ https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
524
+
525
+ #### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
526
+
527
+ <details>
528
+
529
+ <summary>快速开始</summary>
530
+
531
+ 运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
532
+
533
+ ```python
534
+ import torch
535
+ from diffsynth.utils.data import save_video, VideoData
536
+ from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
537
+
538
+ vram_config = {
539
+ "offload_dtype": "disk",
540
+ "offload_device": "disk",
541
+ "onload_dtype": torch.bfloat16,
542
+ "onload_device": "cpu",
543
+ "preparing_dtype": torch.bfloat16,
544
+ "preparing_device": "cuda",
545
+ "computation_dtype": torch.bfloat16,
546
+ "computation_device": "cuda",
547
+ }
548
+ pipe = WanVideoPipeline.from_pretrained(
549
+ torch_dtype=torch.bfloat16,
550
+ device="cuda",
551
+ model_configs=[
552
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
553
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
554
+ ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
555
+ ],
556
+ tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
557
+ vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
558
+ )
559
+
560
+ video = pipe(
561
+ prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
562
+ negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
563
+ seed=0, tiled=True,
564
+ )
565
+ save_video(video, "video.mp4", fps=15, quality=5)
566
+ ```
567
+
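+ 对于图生视频模型(如 Wan2.1-I2V-14B-480P),推理时额外接收 `input_image` 参数(参数名见下方示例代码表格)。以下为一个极简示意,假设 `pipe` 已按链接的示例脚本加载了对应的 I2V 模型权重,首帧图片路径仅作占位:
+
+ ```python
+ from PIL import Image
+
+ # 假设 pipe 已加载 Wan2.1-I2V 系列模型(具体模型文件清单以示例脚本为准)
+ first_frame = Image.open("first_frame.jpg")  # 占位的首帧图片路径
+ video = pipe(
+     prompt="一只活泼的小狗在绿茵茵的草地上迅速奔跑",
+     input_image=first_frame,  # I2V 模型的首帧输入
+     seed=0, tiled=True,
+ )
+ save_video(video, "video_i2v.mp4", fps=15, quality=5)
+ ```
+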
568
+ </details>
569
+
570
+ <details>
571
+
572
+ <summary>模型血缘</summary>
573
+
574
+ ```mermaid
575
+ graph LR;
576
+ Wan-Series-->Wan2.1-Series;
577
+ Wan-Series-->Wan2.2-Series;
578
+ Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
579
+ Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
580
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
581
+ Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
582
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
583
+ Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
584
+ iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
585
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
586
+ Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
587
+ Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
588
+ Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
589
+ Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
590
+ Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
591
+ Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
592
+ Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
593
+ Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
594
+ Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
595
+ Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
596
+ Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
597
+ Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
598
+ Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
599
+ Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
600
+ Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
601
+ Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
602
+ Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
603
+ Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
604
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
605
+ Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
606
+ Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
607
+ Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
608
+ Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
609
+ Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
610
+ Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
611
+ Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
612
+ Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
613
+ Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
614
+ ```
615
+
616
+ </details>
617
+
618
+ <details>
619
+
620
+ <summary>示例代码</summary>
621
+
622
+ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
623
+
624
+ |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
625
+ |-|-|-|-|-|-|-|
626
+ |[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
627
+ |[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
628
+ |[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
629
+ |[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
630
+ |[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
631
+ |[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
632
+ |[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
633
+ |[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
634
+ |[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
+ |[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
+ |[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
+ |[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
+ |[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
+ |[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
+ |[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
+ |[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
+ |[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
+ |[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
+ |[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
+ |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
+ |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
+ |[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
+ |[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
+ |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
+ |[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
+ |[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
+ |[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
+ |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
+
+ </details>
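+
+ The second column of the table above lists the extra inputs each model consumes at inference time (for example `input_image` for image-to-video models, or `vace_control_video` and `vace_reference_image` for VACE variants). As a rough orientation only, the sketch below mirrors the calling pattern used by the linked example scripts; the pipeline class name and import paths here are illustrative assumptions, and the per-model scripts remain the authoritative reference.
+
+ ```python
+ # Hedged sketch: load an image-to-video model and pass the extra input named in the table.
+ # The `WanVideoPipeline` import path is an assumption; `ModelConfig(model_id=..., origin_file_pattern=...)`
+ # follows the usage shown in diffsynth/configs/model_configs.py.
+ import torch
+ from PIL import Image
+ from diffsynth import ModelConfig  # assumed export, mirrored from the example scripts
+ from diffsynth.pipelines.wan_video import WanVideoPipeline  # hypothetical import path
+
+ pipe = WanVideoPipeline.from_pretrained(
+     torch_dtype=torch.bfloat16,
+     device="cuda",
+     model_configs=[
+         ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+         ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+         ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+         ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth"),
+     ],
+ )
+ # `input_image` is the extra input listed for this model in the table; saving the output is omitted.
+ video = pipe(prompt="a cat running on the grass", input_image=Image.open("cat.jpg"), seed=0)
+ ```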
+
+ ## Innovations
+
+ DiffSynth-Studio is more than an engineering-oriented model framework: it is also an incubator for research innovations.
+
+ <details>
+
+ <summary>AttriCtrl: Attribute Intensity Control for Image Generation Models</summary>
+
+ - Paper: [AttriCtrl: Fine-Grained Control of Aesthetic Attribute Intensity in Diffusion Models](https://arxiv.org/abs/2508.02151)
+ - Code example: [/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py](/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py)
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev)
+
+ |brightness scale = 0.1|brightness scale = 0.3|brightness scale = 0.5|brightness scale = 0.7|brightness scale = 0.9|
+ |-|-|-|-|-|
+ |![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.5.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.7.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/AttriCtrl-FLUX.1-Dev/resolve/master/assets/brightness/value_control_0.9.jpg)|
+
+ </details>
+
+
+ <details>
+
+ <summary>AutoLoRA: Automated LoRA Retrieval and Fusion</summary>
+
+ - Paper: [AutoLoRA: Automatic LoRA Retrieval and Fine-Grained Gated Fusion for Text-to-Image Generation](https://arxiv.org/abs/2508.02107)
+ - Code example: [/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)
+
+ ||[LoRA 1](https://modelscope.cn/models/cancel13/cxsk)|[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2)|[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)|[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL)|
+ |-|-|-|-|-|
+ |[LoRA 1](https://modelscope.cn/models/cancel13/cxsk) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_0.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|
+ |[LoRA 2](https://modelscope.cn/models/wy413928499/xuancai2) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_1.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|
+ |[LoRA 3](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_2.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|
+ |[LoRA 4](https://modelscope.cn/models/hongyanbujian/JPL) |![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_0_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_1_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_2_3.jpg)|![](https://www.modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev/resolve/master/assets/car/image_3_3.jpg)|
+
+ </details>
+
+
+ <details>
+
+ <summary>Nexus-Gen: Unified Image Understanding, Generation, and Editing</summary>
+
+ - Detail page: https://github.com/modelscope/Nexus-Gen
+ - Paper: [Nexus-Gen: Unified Image Understanding, Generation, and Editing via Prefilled Autoregression in Shared Embedding Space](https://arxiv.org/pdf/2504.21356)
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2), [HuggingFace](https://huggingface.co/modelscope/Nexus-GenV2)
+ - Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
+ - Online demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
+
+ ![](https://github.com/modelscope/Nexus-Gen/raw/main/assets/illustrations/gen_edit.jpg)
+
+ </details>
+
+
+ <details>
+
+ <summary>ArtAug: Aesthetic Enhancement for Image Generation Models</summary>
+
+ - Detail page: [./examples/ArtAug/](./examples/ArtAug/)
+ - Paper: [ArtAug: Enhancing Text-to-Image Generation through Synthesis-Understanding Interaction](https://arxiv.org/abs/2412.12888)
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ArtAug-lora-FLUX.1dev-v1)
+ - 在线体验:[ModelScope AIGC Tab](https://www.modelscope.cn/aigc/imageGeneration?tab=advanced&versionId=7228&modelType=LoRA&sdVersion=FLUX_1&modelUrl=modelscope%3A%2F%2FDiffSynth-Studio%2FArtAug-lora-FLUX.1dev-v1%3Frevision%3Dv1.0)
+
+ |FLUX.1-dev|FLUX.1-dev + ArtAug LoRA|
+ |-|-|
+ |![image_1_base](https://github.com/user-attachments/assets/e1d5c505-b423-45fe-be01-25c2758f5417)|![image_1_enhance](https://github.com/user-attachments/assets/335908e3-d0bd-41c2-9d99-d10528a2d719)|
+
+ </details>
+
+
+ <details>
+
+ <summary>EliGen: Precise Entity-Level Region Control for Image Generation</summary>
+
+ - Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097)
+ - Code example: [/examples/flux/model_inference/FLUX.1-dev-EliGen.py](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)
+ - Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen), [HuggingFace](https://huggingface.co/modelscope/EliGen)
+ - Online demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen)
+ - Dataset: [EliGen Train Set](https://www.modelscope.cn/datasets/DiffSynth-Studio/EliGenTrainSet)
+
+ |Entity control regions|Generated image|
+ |-|-|
+ |![eligen_example_2_mask_0](https://github.com/user-attachments/assets/1c6d9445-5022-4d91-ad2e-dc05321883d1)|![eligen_example_2_0](https://github.com/user-attachments/assets/86739945-cb07-4a49-b3b3-3bb65c90d14f)|
+
+ </details>
+
+
+ <details>
+
+ <summary>ExVideo: Extended Training for Video Generation Models</summary>
+
+ - Project page: [Project Page](https://ecnu-cilab.github.io/ExVideoProjectPage/)
+ - Paper: [ExVideo: Extending Video Diffusion Models via Parameter-Efficient Post-Tuning](https://arxiv.org/abs/2406.14130)
+ - Code examples: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/ExVideo)
+ - Model: [ModelScope](https://modelscope.cn/models/ECNU-CILab/ExVideo-SVD-128f-v1), [HuggingFace](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
+
+ https://github.com/modelscope/DiffSynth-Studio/assets/35051019/d97f6aa9-8064-4b5b-9d49-ed6001bb9acc
+
+ </details>
+
+
+ <details>
+
+ <summary>Diffutoon: High-Resolution Anime-Style Video Rendering</summary>
+
+ - Project page: [Project Page](https://ecnu-cilab.github.io/DiffutoonProjectPage/)
+ - Paper: [Diffutoon: High-Resolution Editable Toon Shading via Diffusion Models](https://arxiv.org/abs/2401.16224)
+ - Code examples: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/Diffutoon)
+
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
+
+ </details>
+
+
+ <details>
+
+ <summary>DiffSynth: The Original Version of This Project</summary>
+
+ - Project page: [Project Page](https://ecnu-cilab.github.io/DiffSynth.github.io/)
+ - Paper: [DiffSynth: Latent In-Iteration Deflickering for Realistic Video Synthesis](https://arxiv.org/abs/2308.03463)
+ - Code examples: see the [legacy version](https://github.com/modelscope/DiffSynth-Studio/tree/afd101f3452c9ecae0c87b79adfa2e22d65ffdc3/examples/diffsynth)
+
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
+
+ </details>
assets/egg.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:96647e399f3a7772d32bee1562a0c01fe8b273f6ad5fe2708cc10878972fce04
+ size 4810228
comp_attn_bbox_layout.png ADDED
comp_attn_trajectory.png ADDED
diffsynth/__init__.py ADDED
@@ -0,0 +1 @@
+ from .core import *
diffsynth/configs/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .model_configs import MODEL_CONFIGS
+ from .vram_management_module_maps import VRAM_MANAGEMENT_MODULE_MAPS
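+
+ # Editorial note (assumed usage, not part of the upstream file): these re-exports let callers write
+ #   from diffsynth.configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
+ # to access the model registry and the VRAM-wrapping rules defined in this package.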
diffsynth/configs/model_configs.py ADDED
@@ -0,0 +1,518 @@
1
+ qwen_image_series = [
2
+ {
3
+ # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors")
4
+ "model_hash": "0319a1cb19835fb510907dd3367c95ff",
5
+ "model_name": "qwen_image_dit",
6
+ "model_class": "diffsynth.models.qwen_image_dit.QwenImageDiT",
7
+ },
8
+ {
9
+ # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
10
+ "model_hash": "8004730443f55db63092006dd9f7110e",
11
+ "model_name": "qwen_image_text_encoder",
12
+ "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
13
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter",
14
+ },
15
+ {
16
+ # Example: ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
17
+ "model_hash": "ed4ea5824d55ec3107b09815e318123a",
18
+ "model_name": "qwen_image_vae",
19
+ "model_class": "diffsynth.models.qwen_image_vae.QwenImageVAE",
20
+ },
21
+ {
22
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth", origin_file_pattern="model.safetensors")
23
+ "model_hash": "073bce9cf969e317e5662cd570c3e79c",
24
+ "model_name": "qwen_image_blockwise_controlnet",
25
+ "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
26
+ },
27
+ {
28
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors")
29
+ "model_hash": "a9e54e480a628f0b956a688a81c33bab",
30
+ "model_name": "qwen_image_blockwise_controlnet",
31
+ "model_class": "diffsynth.models.qwen_image_controlnet.QwenImageBlockWiseControlNet",
32
+ "extra_kwargs": {"additional_in_dim": 4},
33
+ },
34
+ {
35
+ # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors")
36
+ "model_hash": "469c78b61e3e31bc9eec0d0af3d3f2f8",
37
+ "model_name": "siglip2_image_encoder",
38
+ "model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder",
39
+ },
40
+ {
41
+ # Example: ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors")
42
+ "model_hash": "5722b5c873720009de96422993b15682",
43
+ "model_name": "dinov3_image_encoder",
44
+ "model_class": "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder",
45
+ },
46
+ {
47
+ # Example:
48
+ "model_hash": "a166c33455cdbd89c0888a3645ca5c0f",
49
+ "model_name": "qwen_image_image2lora_coarse",
50
+ "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
51
+ },
52
+ {
53
+ # Example:
54
+ "model_hash": "a5476e691767a4da6d3a6634a10f7408",
55
+ "model_name": "qwen_image_image2lora_fine",
56
+ "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
57
+ "extra_kwargs": {"residual_length": 37*37+7, "residual_mid_dim": 64}
58
+ },
59
+ {
60
+ # Example:
61
+ "model_hash": "0aad514690602ecaff932c701cb4b0bb",
62
+ "model_name": "qwen_image_image2lora_style",
63
+ "model_class": "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel",
64
+ "extra_kwargs": {"compress_dim": 64, "use_residual": False}
65
+ },
66
+ ]
67
+
68
+ wan_series = [
69
+ {
70
+ # Example: ModelConfig(model_id="krea/krea-realtime-video", origin_file_pattern="krea-realtime-video-14b.safetensors")
71
+ "model_hash": "5ec04e02b42d2580483ad69f4e76346a",
72
+ "model_name": "wan_video_dit",
73
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
74
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
75
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
76
+ },
77
+ {
78
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth")
79
+ "model_hash": "9c8818c2cbea55eca56c7b447df170da",
80
+ "model_name": "wan_video_text_encoder",
81
+ "model_class": "diffsynth.models.wan_video_text_encoder.WanTextEncoder",
82
+ },
83
+ {
84
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth")
85
+ "model_hash": "ccc42284ea13e1ad04693284c7a09be6",
86
+ "model_name": "wan_video_vae",
87
+ "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE",
88
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
89
+ },
90
+ {
91
+ # Example: ModelConfig(model_id="meituan-longcat/LongCat-Video", origin_file_pattern="dit/diffusion_pytorch_model*.safetensors")
92
+ "model_hash": "8b27900f680d7251ce44e2dc8ae1ffef",
93
+ "model_name": "wan_video_dit",
94
+ "model_class": "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel",
95
+ },
96
+ {
97
+ # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
98
+ "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
99
+ "model_name": "wan_video_dit",
100
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
101
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
102
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTFromDiffusers"
103
+ },
104
+ {
105
+ # Example: ModelConfig(model_id="ByteDance/Video-As-Prompt-Wan2.1-14B", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors")
106
+ "model_hash": "5f90e66a0672219f12d9a626c8c21f61",
107
+ "model_name": "wan_video_vap",
108
+ "model_class": "diffsynth.models.wan_video_mot.MotWanModel",
109
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_mot.WanVideoMotStateDictConverter"
110
+ },
111
+ {
112
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
113
+ "model_hash": "5941c53e207d62f20f9025686193c40b",
114
+ "model_name": "wan_video_image_encoder",
115
+ "model_class": "diffsynth.models.wan_video_image_encoder.WanImageEncoder",
116
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_image_encoder.WanImageEncoderStateDictConverter"
117
+ },
118
+ {
119
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors")
120
+ "model_hash": "dbd5ec76bbf977983f972c151d545389",
121
+ "model_name": "wan_video_motion_controller",
122
+ "model_class": "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel",
123
+ },
124
+ {
125
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
126
+ "model_hash": "9269f8db9040a9d860eaca435be61814",
127
+ "model_name": "wan_video_dit",
128
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
129
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
130
+ },
131
+ {
132
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors")
133
+ "model_hash": "3ef3b1f8e1dab83d5b71fd7b617f859f",
134
+ "model_name": "wan_video_dit",
135
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
136
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_image_pos_emb': True}
137
+ },
138
+ {
139
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
140
+ "model_hash": "349723183fc063b2bfc10bb2835cf677",
141
+ "model_name": "wan_video_dit",
142
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
143
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
144
+ },
145
+ {
146
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
147
+ "model_hash": "6d6ccde6845b95ad9114ab993d917893",
148
+ "model_name": "wan_video_dit",
149
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
150
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06}
151
+ },
152
+ {
153
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
154
+ "model_hash": "efa44cddf936c70abd0ea28b6cbe946c",
155
+ "model_name": "wan_video_dit",
156
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
157
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
158
+ },
159
+ {
160
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors")
161
+ "model_hash": "6bfcfb3b342cb286ce886889d519a77e",
162
+ "model_name": "wan_video_dit",
163
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
164
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
165
+ },
166
+ {
167
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
168
+ "model_hash": "ac6a5aa74f4a0aab6f64eb9a72f19901",
169
+ "model_name": "wan_video_dit",
170
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
171
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
172
+ },
173
+ {
174
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
175
+ "model_hash": "70ddad9d3a133785da5ea371aae09504",
176
+ "model_name": "wan_video_dit",
177
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
178
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06, 'has_ref_conv': True}
179
+ },
180
+ {
181
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors")
182
+ "model_hash": "b61c605c2adbd23124d152ed28e049ae",
183
+ "model_name": "wan_video_dit",
184
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
185
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 32, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24}
186
+ },
187
+ {
188
+ # Example: ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors")
189
+ "model_hash": "26bde73488a92e64cc20b0a7485b9e5b",
190
+ "model_name": "wan_video_dit",
191
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
192
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True}
193
+ },
194
+ {
195
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
196
+ "model_hash": "aafcfd9672c3a2456dc46e1cb6e52c70",
197
+ "model_name": "wan_video_dit",
198
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
199
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06}
200
+ },
201
+ {
202
+ # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
203
+ "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
204
+ "model_name": "wan_video_dit",
205
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
206
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 1536, 'ffn_dim': 8960, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 12, 'num_layers': 30, 'eps': 1e-06},
207
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
208
+ },
209
+ {
210
+ # Example: ModelConfig(model_id="iic/VACE-Wan2.1-1.3B-Preview", origin_file_pattern="diffusion_pytorch_model*.safetensors")
211
+ "model_hash": "a61453409b67cd3246cf0c3bebad47ba",
212
+ "model_name": "wan_video_vace",
213
+ "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
214
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
215
+ },
216
+ {
217
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
218
+ "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
219
+ "model_name": "wan_video_dit",
220
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
221
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 16, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
222
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter",
223
+ },
224
+ {
225
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.1-VACE-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
226
+ "model_hash": "7a513e1f257a861512b1afd387a8ecd9",
227
+ "model_name": "wan_video_vace",
228
+ "model_class": "diffsynth.models.wan_video_vace.VaceWanModel",
229
+ "extra_kwargs": {'vace_layers': (0, 5, 10, 15, 20, 25, 30, 35), 'vace_in_dim': 96, 'patch_size': (1, 2, 2), 'has_image_input': False, 'dim': 5120, 'num_heads': 40, 'ffn_dim': 13824, 'eps': 1e-06},
230
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vace.VaceWanModelDictConverter"
231
+ },
232
+ {
233
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
234
+ "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
235
+ "model_name": "wan_video_dit",
236
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
237
+ "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06},
238
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_dit.WanVideoDiTStateDictConverter"
239
+ },
240
+ {
241
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
242
+ "model_hash": "31fa352acb8a1b1d33cd8764273d80a2",
243
+ "model_name": "wan_video_animate_adapter",
244
+ "model_class": "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter",
245
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_animate_adapter.WanAnimateAdapterStateDictConverter"
246
+ },
247
+ {
248
+ # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
249
+ "model_hash": "47dbeab5e560db3180adf51dc0232fb1",
250
+ "model_name": "wan_video_dit",
251
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
252
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': False, 'add_control_adapter': True, 'in_dim_control_adapter': 24, 'require_clip_embedding': False}
253
+ },
254
+ {
255
+ # Example: ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
256
+ "model_hash": "2267d489f0ceb9f21836532952852ee5",
257
+ "model_name": "wan_video_dit",
258
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
259
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 52, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'has_ref_conv': True, 'require_clip_embedding': False},
260
+ },
261
+ {
262
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-I2V-A14B", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors")
263
+ "model_hash": "5b013604280dd715f8457c6ed6d6a626",
264
+ "model_name": "wan_video_dit",
265
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
266
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'require_clip_embedding': False}
267
+ },
268
+ {
269
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
270
+ "model_hash": "966cffdcc52f9c46c391768b27637614",
271
+ "model_name": "wan_video_dit",
272
+ "model_class": "diffsynth.models.wan_video_dit_s2v.WanS2VModel",
273
+ "extra_kwargs": {'dim': 5120, 'in_dim': 16, 'ffn_dim': 13824, 'out_dim': 16, 'text_dim': 4096, 'freq_dim': 256, 'eps': 1e-06, 'patch_size': (1, 2, 2), 'num_heads': 40, 'num_layers': 40, 'cond_dim': 16, 'audio_dim': 1024, 'num_audio_token': 4}
274
+ },
275
+ {
276
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
277
+ "model_hash": "1f5ab7703c6fc803fdded85ff040c316",
278
+ "model_name": "wan_video_dit",
279
+ "model_class": "diffsynth.models.wan_video_dit.WanModel",
280
+ "extra_kwargs": {'has_image_input': False, 'patch_size': [1, 2, 2], 'in_dim': 48, 'dim': 3072, 'ffn_dim': 14336, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 48, 'num_heads': 24, 'num_layers': 30, 'eps': 1e-06, 'seperated_timestep': True, 'require_clip_embedding': False, 'require_vae_embedding': False, 'fuse_vae_embedding_in_latents': True}
281
+ },
282
+ {
283
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth")
284
+ "model_hash": "e1de6c02cdac79f8b739f4d3698cd216",
285
+ "model_name": "wan_video_vae",
286
+ "model_class": "diffsynth.models.wan_video_vae.WanVideoVAE38",
287
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wan_video_vae.WanVideoVAEStateDictConverter",
288
+ },
289
+ {
290
+ # Example: ModelConfig(model_id="Wan-AI/Wan2.2-S2V-14B", origin_file_pattern="wav2vec2-large-xlsr-53-english/model.safetensors")
291
+ "model_hash": "06be60f3a4526586d8431cd038a71486",
292
+ "model_name": "wans2v_audio_encoder",
293
+ "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
294
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
295
+ },
296
+ ]
297
+
298
+ flux_series = [
299
+ {
300
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors")
301
+ "model_hash": "a29710fea6dddb0314663ee823598e50",
302
+ "model_name": "flux_dit",
303
+ "model_class": "diffsynth.models.flux_dit.FluxDiT",
304
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
305
+ },
306
+ {
307
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
308
+ "model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
309
+ "model_name": "flux_text_encoder_clip",
310
+ "model_class": "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip",
311
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_clip.FluxTextEncoderClipStateDictConverter",
312
+ },
313
+ {
314
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors")
315
+ "model_hash": "22540b49eaedbc2f2784b2091a234c7c",
316
+ "model_name": "flux_text_encoder_t5",
317
+ "model_class": "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5",
318
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_text_encoder_t5.FluxTextEncoderT5StateDictConverter",
319
+ },
320
+ {
321
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
322
+ "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
323
+ "model_name": "flux_vae_encoder",
324
+ "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
325
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverter",
326
+ },
327
+ {
328
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors")
329
+ "model_hash": "21ea55f476dfc4fd135587abb59dfe5d",
330
+ "model_name": "flux_vae_decoder",
331
+ "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
332
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverter",
333
+ },
334
+ {
335
+ # Example: ModelConfig(model_id="ostris/Flex.2-preview", origin_file_pattern="Flex.2-preview.safetensors")
336
+ "model_hash": "d02f41c13549fa5093d3521f62a5570a",
337
+ "model_name": "flux_dit",
338
+ "model_class": "diffsynth.models.flux_dit.FluxDiT",
339
+ "extra_kwargs": {'input_dim': 196, 'num_blocks': 8},
340
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
341
+ },
342
+ {
343
+ # Example: ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
344
+ "model_hash": "0629116fce1472503a66992f96f3eb1a",
345
+ "model_name": "flux_value_controller",
346
+ "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
347
+ },
348
+ {
349
+ # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
350
+ "model_hash": "52357cb26250681367488a8954c271e8",
351
+ "model_name": "flux_controlnet",
352
+ "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
353
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
354
+ "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
355
+ },
356
+ {
357
+ # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
358
+ "model_hash": "78d18b9101345ff695f312e7e62538c0",
359
+ "model_name": "flux_controlnet",
360
+ "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
361
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
362
+ "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
363
+ },
364
+ {
365
+ # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
366
+ "model_hash": "b001c89139b5f053c715fe772362dd2a",
367
+ "model_name": "flux_controlnet",
368
+ "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
369
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
370
+ "extra_kwargs": {"num_single_blocks": 0},
371
+ },
372
+ {
373
+ # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
374
+ "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
375
+ "model_name": "infiniteyou_image_projector",
376
+ "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
377
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
378
+ },
379
+ {
380
+ # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
381
+ "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
382
+ "model_name": "flux_controlnet",
383
+ "model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
384
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
385
+ "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
386
+ },
387
+ {
388
+ # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
389
+ "model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
390
+ "model_name": "flux_lora_encoder",
391
+ "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
392
+ },
393
+ {
394
+ # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
395
+ "model_hash": "30143afb2dea73d1ac580e0787628f8c",
396
+ "model_name": "flux_lora_patcher",
397
+ "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
398
+ },
399
+ {
400
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
401
+ "model_hash": "2bd19e845116e4f875a0a048e27fc219",
402
+ "model_name": "nexus_gen_llm",
403
+ "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
404
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
405
+ },
406
+ {
407
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
408
+ "model_hash": "63c969fd37cce769a90aa781fbff5f81",
409
+ "model_name": "nexus_gen_editing_adapter",
410
+ "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
411
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
412
+ },
413
+ {
414
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
415
+ "model_hash": "63c969fd37cce769a90aa781fbff5f81",
416
+ "model_name": "flux_dit",
417
+ "model_class": "diffsynth.models.flux_dit.FluxDiT",
418
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
419
+ },
420
+ {
421
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
422
+ "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
423
+ "model_name": "nexus_gen_generation_adapter",
424
+ "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
425
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
426
+ },
427
+ {
428
+ # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
429
+ "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
430
+ "model_name": "flux_dit",
431
+ "model_class": "diffsynth.models.flux_dit.FluxDiT",
432
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
433
+ },
434
+ {
435
+ # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
436
+ "model_hash": "4daaa66cc656a8fe369908693dad0a35",
437
+ "model_name": "flux_ipadapter",
438
+ "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
439
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
440
+ },
441
+ {
442
+ # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
443
+ "model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
444
+ "model_name": "siglip_vision_model",
445
+ "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
446
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
447
+ },
448
+ {
449
+ # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
450
+ "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
451
+ "model_name": "step1x_connector",
452
+ "model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
453
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
454
+ },
455
+ {
456
+ # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
457
+ "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
458
+ "model_name": "flux_dit",
459
+ "model_class": "diffsynth.models.flux_dit.FluxDiT",
460
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
461
+ "extra_kwargs": {"disable_guidance_embedder": True},
462
+ },
463
+ ]
464
+
465
+ flux2_series = [
466
+ {
467
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors")
468
+ "model_hash": "28fca3d8e5bf2a2d1271748a773f6757",
469
+ "model_name": "flux2_text_encoder",
470
+ "model_class": "diffsynth.models.flux2_text_encoder.Flux2TextEncoder",
471
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux2_text_encoder.Flux2TextEncoderStateDictConverter",
472
+ },
473
+ {
474
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors")
475
+ "model_hash": "d38e1d5c5aec3b0a11e79327ac6e3b0f",
476
+ "model_name": "flux2_dit",
477
+ "model_class": "diffsynth.models.flux2_dit.Flux2DiT",
478
+ },
479
+ {
480
+ # Example: ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors")
481
+ "model_hash": "c54288e3ee12ca215898840682337b95",
482
+ "model_name": "flux2_vae",
483
+ "model_class": "diffsynth.models.flux2_vae.Flux2VAE",
484
+ },
485
+ ]
486
+
487
+ z_image_series = [
488
+ {
489
+ # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors")
490
+ "model_hash": "fc3a8a1247fe185ce116ccbe0e426c28",
491
+ "model_name": "z_image_dit",
492
+ "model_class": "diffsynth.models.z_image_dit.ZImageDiT",
493
+ },
494
+ {
495
+ # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors")
496
+ "model_hash": "0f050f62a88876fea6eae0a18dac5a2e",
497
+ "model_name": "z_image_text_encoder",
498
+ "model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
499
+ },
500
+ {
501
+ # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
502
+ "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
503
+ "model_name": "flux_vae_encoder",
504
+ "model_class": "diffsynth.models.flux_vae.FluxVAEEncoder",
505
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEEncoderStateDictConverterDiffusers",
506
+ "extra_kwargs": {"use_conv_attention": False},
507
+ },
508
+ {
509
+ # Example: ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors")
510
+ "model_hash": "1aafa3cc91716fb6b300cc1cd51b85a3",
511
+ "model_name": "flux_vae_decoder",
512
+ "model_class": "diffsynth.models.flux_vae.FluxVAEDecoder",
513
+ "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_vae.FluxVAEDecoderStateDictConverterDiffusers",
514
+ "extra_kwargs": {"use_conv_attention": False},
515
+ },
516
+ ]
517
+
518
+ MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
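+
+ # --- Hedged usage sketch (editorial addition, not part of the upstream file) ---
+ # Each entry above pairs a state-dict fingerprint ("model_hash") with the dotted path of the
+ # class that can load it, plus optional "extra_kwargs" and a "state_dict_converter".
+ # The helper below only illustrates how such a registry is typically resolved; the function
+ # name and behavior are assumptions, and weight loading / state-dict conversion are omitted.
+ import importlib
+
+ def _resolve_model_class(model_hash: str):
+     """Illustrative lookup: map a detected model hash to an instantiated (untrained) model."""
+     for config in MODEL_CONFIGS:
+         if config["model_hash"] == model_hash:
+             module_name, class_name = config["model_class"].rsplit(".", 1)
+             model_class = getattr(importlib.import_module(module_name), class_name)
+             return model_class(**config.get("extra_kwargs", {}))
+     raise ValueError(f"No model config registered for hash {model_hash}")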
diffsynth/configs/vram_management_module_maps.py ADDED
@@ -0,0 +1,197 @@
1
+ flux_general_vram_config = {
2
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
3
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
4
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
5
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
6
+ "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
7
+ "diffsynth.models.general_modules.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
8
+ "diffsynth.models.flux_lora_encoder.LoRALayerBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
9
+ "diffsynth.models.flux_lora_patcher.LoraMerger": "diffsynth.core.vram.layers.AutoWrappedModule",
10
+ }
11
+
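+ # Hedged note (editorial addition, not part of the upstream file): each map below keys a module
+ # class's dotted path to the wrapper that presumably replaces it when VRAM management is enabled --
+ # Linear layers map to AutoWrappedLinear, other layers to AutoWrappedModule (both in
+ # diffsynth.core.vram.layers). A wrapping pass is assumed to work roughly like:
+ #   for name, module in model.named_modules():
+ #       path = f"{type(module).__module__}.{type(module).__name__}"
+ #       if path in module_map: ... swap in the mapped AutoWrapped* class for on-demand offloading
+ # The exact mechanism is defined elsewhere in this repository.
+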
12
+ VRAM_MANAGEMENT_MODULE_MAPS = {
13
+ "diffsynth.models.qwen_image_dit.QwenImageDiT": {
14
+ "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
15
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
16
+ },
17
+ "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder": {
18
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
19
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
20
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
21
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
22
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionPatchEmbed": "diffsynth.core.vram.layers.AutoWrappedModule",
23
+ "transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
24
+ },
25
+ "diffsynth.models.qwen_image_vae.QwenImageVAE": {
26
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
27
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
28
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
29
+ "diffsynth.models.qwen_image_vae.QwenImageRMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
30
+ },
31
+ "diffsynth.models.qwen_image_controlnet.BlockWiseControlBlock": {
32
+ "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
33
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
34
+ },
35
+ "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder": {
36
+ "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
37
+ "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
38
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
39
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
40
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
41
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
42
+ },
43
+ "diffsynth.models.dinov3_image_encoder.DINOv3ImageEncoder": {
44
+ "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTLayerScale": "diffsynth.core.vram.layers.AutoWrappedModule",
45
+ "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTRopePositionEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
46
+ "transformers.models.dinov3_vit.modeling_dinov3_vit.DINOv3ViTEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
47
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
48
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
49
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
50
+ },
51
+ "diffsynth.models.qwen_image_image2lora.QwenImageImage2LoRAModel": {
52
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
53
+ },
54
+ "diffsynth.models.wan_video_animate_adapter.WanAnimateAdapter": {
55
+ "diffsynth.models.wan_video_animate_adapter.FaceEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
56
+ "diffsynth.models.wan_video_animate_adapter.EqualLinear": "diffsynth.core.vram.layers.AutoWrappedModule",
57
+ "diffsynth.models.wan_video_animate_adapter.ConvLayer": "diffsynth.core.vram.layers.AutoWrappedModule",
58
+ "diffsynth.models.wan_video_animate_adapter.FusedLeakyReLU": "diffsynth.core.vram.layers.AutoWrappedModule",
59
+ "diffsynth.models.wan_video_animate_adapter.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
60
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
61
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
62
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
63
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
64
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
65
+ },
66
+ "diffsynth.models.wan_video_dit_s2v.WanS2VModel": {
67
+ "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
68
+ "diffsynth.models.wan_video_dit_s2v.WanS2VDiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
69
+ "diffsynth.models.wan_video_dit_s2v.CausalAudioEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
70
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
71
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
72
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
73
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
74
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
75
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
76
+ },
77
+ "diffsynth.models.wan_video_dit.WanModel": {
78
+ "diffsynth.models.wan_video_dit.MLP": "diffsynth.core.vram.layers.AutoWrappedModule",
79
+ "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedNonRecurseModule",
80
+ "diffsynth.models.wan_video_dit.Head": "diffsynth.core.vram.layers.AutoWrappedModule",
81
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
82
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
83
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
84
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
85
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
86
+ },
87
+ "diffsynth.models.wan_video_image_encoder.WanImageEncoder": {
88
+ "diffsynth.models.wan_video_image_encoder.VisionTransformer": "diffsynth.core.vram.layers.AutoWrappedModule",
89
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
90
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
91
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
92
+ },
93
+ "diffsynth.models.wan_video_mot.MotWanModel": {
94
+ "diffsynth.models.wan_video_mot.MotWanAttentionBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
95
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
96
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
97
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
98
+ },
99
+ "diffsynth.models.wan_video_motion_controller.WanMotionControllerModel": {
100
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
101
+ },
102
+ "diffsynth.models.wan_video_text_encoder.WanTextEncoder": {
103
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
104
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
105
+ "diffsynth.models.wan_video_text_encoder.T5RelativeEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
106
+ "diffsynth.models.wan_video_text_encoder.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
107
+ },
108
+ "diffsynth.models.wan_video_vace.VaceWanModel": {
109
+ "diffsynth.models.wan_video_dit.DiTBlock": "diffsynth.core.vram.layers.AutoWrappedModule",
110
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
111
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
112
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
113
+ "diffsynth.models.wan_video_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
114
+ },
115
+ "diffsynth.models.wan_video_vae.WanVideoVAE": {
116
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
117
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
118
+ "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
119
+ "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
120
+ "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
121
+ "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
122
+ "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
123
+ },
124
+ "diffsynth.models.wan_video_vae.WanVideoVAE38": {
125
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
126
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
127
+ "diffsynth.models.wan_video_vae.RMS_norm": "diffsynth.core.vram.layers.AutoWrappedModule",
128
+ "diffsynth.models.wan_video_vae.CausalConv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
129
+ "diffsynth.models.wan_video_vae.Upsample": "diffsynth.core.vram.layers.AutoWrappedModule",
130
+ "torch.nn.SiLU": "diffsynth.core.vram.layers.AutoWrappedModule",
131
+ "torch.nn.Dropout": "diffsynth.core.vram.layers.AutoWrappedModule",
132
+ },
133
+ "diffsynth.models.wav2vec.WanS2VAudioEncoder": {
134
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
135
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
136
+ "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
137
+ },
138
+ "diffsynth.models.longcat_video_dit.LongCatVideoTransformer3DModel": {
139
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
140
+ "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
141
+ "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
142
+ "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
143
+ },
144
+ "diffsynth.models.flux_dit.FluxDiT": {
145
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
146
+ "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
147
+ },
148
+ "diffsynth.models.flux_text_encoder_clip.FluxTextEncoderClip": flux_general_vram_config,
149
+ "diffsynth.models.flux_vae.FluxVAEEncoder": flux_general_vram_config,
150
+ "diffsynth.models.flux_vae.FluxVAEDecoder": flux_general_vram_config,
151
+ "diffsynth.models.flux_controlnet.FluxControlNet": flux_general_vram_config,
152
+ "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector": flux_general_vram_config,
153
+ "diffsynth.models.flux_ipadapter.FluxIpAdapter": flux_general_vram_config,
154
+ "diffsynth.models.flux_lora_patcher.FluxLoraPatcher": flux_general_vram_config,
155
+ "diffsynth.models.step1x_connector.Qwen2Connector": flux_general_vram_config,
156
+ "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder": flux_general_vram_config,
157
+ "diffsynth.models.flux_text_encoder_t5.FluxTextEncoderT5": {
158
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
159
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
160
+ "transformers.models.t5.modeling_t5.T5LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
161
+ "transformers.models.t5.modeling_t5.T5DenseActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
162
+ "transformers.models.t5.modeling_t5.T5DenseGatedActDense": "diffsynth.core.vram.layers.AutoWrappedModule",
163
+ },
164
+ "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M": {
165
+ "transformers.models.siglip.modeling_siglip.SiglipVisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
166
+ "transformers.models.siglip.modeling_siglip.SiglipEncoder": "diffsynth.core.vram.layers.AutoWrappedModule",
167
+ "transformers.models.siglip.modeling_siglip.SiglipMultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
168
+ "torch.nn.MultiheadAttention": "diffsynth.core.vram.layers.AutoWrappedModule",
169
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
170
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
171
+ },
172
+ "diffsynth.models.flux2_dit.Flux2DiT": {
173
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
174
+ "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
175
+ "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
176
+ },
177
+ "diffsynth.models.flux2_text_encoder.Flux2TextEncoder": {
178
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
179
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
180
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
181
+ "transformers.models.mistral.modeling_mistral.MistralRMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
182
+ },
183
+ "diffsynth.models.flux2_vae.Flux2VAE": {
184
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
185
+ "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
186
+ "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
187
+ },
188
+ "diffsynth.models.z_image_text_encoder.ZImageTextEncoder": {
189
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
190
+ "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
191
+ "torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
192
+ },
193
+ "diffsynth.models.z_image_dit.ZImageDiT": {
194
+ "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
195
+ "diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
196
+ },
197
+ }
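
`VRAM_MANAGEMENT_MODULE_MAPS` lists, per model class, which submodule types should be swapped for the auto-wrapped layers defined in `diffsynth/core/vram/layers.py` later in this diff. How the dotted strings are resolved is not shown in this section; the sketch below assumes a simple importlib-based resolution before the map is handed to `enable_vram_management`.

```python
# A minimal sketch, assuming the dotted names are resolved to classes before
# being passed to enable_vram_management (diffsynth/core/vram/layers.py).
import importlib, torch

def resolve(dotted_path):
    module_name, name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), name)

def build_module_map(raw_map):
    # {"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", ...}
    return {resolve(src): resolve(dst) for src, dst in raw_map.items()}

vram_config = {
    "offload_dtype": torch.bfloat16, "offload_device": "cpu",
    "onload_dtype": torch.bfloat16, "onload_device": "cpu",
    "preparing_dtype": torch.bfloat16, "preparing_device": "cuda",
    "computation_dtype": torch.bfloat16, "computation_device": "cuda",
}
# module_map = build_module_map(VRAM_MANAGEMENT_MODULE_MAPS["diffsynth.models.flux_dit.FluxDiT"])
# model = enable_vram_management(model, module_map, vram_config=vram_config, vram_limit=8)
```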
diffsynth/core/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .attention import *
2
+ from .data import *
3
+ from .gradient import *
4
+ from .loader import *
5
+ from .vram import *
diffsynth/core/attention/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .attention import attention_forward
diffsynth/core/attention/attention.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch, os
2
+ from einops import rearrange
3
+
4
+
5
+ try:
6
+ import flash_attn_interface
7
+ FLASH_ATTN_3_AVAILABLE = True
8
+ except ModuleNotFoundError:
9
+ FLASH_ATTN_3_AVAILABLE = False
10
+
11
+ try:
12
+ import flash_attn
13
+ FLASH_ATTN_2_AVAILABLE = True
14
+ except ModuleNotFoundError:
15
+ FLASH_ATTN_2_AVAILABLE = False
16
+
17
+ try:
18
+ from sageattention import sageattn
19
+ SAGE_ATTN_AVAILABLE = True
20
+ except ModuleNotFoundError:
21
+ SAGE_ATTN_AVAILABLE = False
22
+
23
+ try:
24
+ import xformers.ops as xops
25
+ XFORMERS_AVAILABLE = True
26
+ except ModuleNotFoundError:
27
+ XFORMERS_AVAILABLE = False
28
+
29
+
30
+ def initialize_attention_priority():
31
+ if os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION') is not None:
32
+ return os.environ.get('DIFFSYNTH_ATTENTION_IMPLEMENTATION').lower()
33
+ elif FLASH_ATTN_3_AVAILABLE:
34
+ return "flash_attention_3"
35
+ elif FLASH_ATTN_2_AVAILABLE:
36
+ return "flash_attention_2"
37
+ elif SAGE_ATTN_AVAILABLE:
38
+ return "sage_attention"
39
+ elif XFORMERS_AVAILABLE:
40
+ return "xformers"
41
+ else:
42
+ return "torch"
43
+
44
+
45
+ ATTENTION_IMPLEMENTATION = initialize_attention_priority()
46
+
47
+
48
+ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", required_in_pattern="b n s d", dims=None):
49
+ dims = {} if dims is None else dims
50
+ if q_pattern != required_in_pattern:
51
+ q = rearrange(q, f"{q_pattern} -> {required_in_pattern}", **dims)
52
+ if k_pattern != required_in_pattern:
53
+ k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
54
+ if v_pattern != required_in_pattern:
55
+ v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
56
+ return q, k, v
57
+
58
+
59
+ def rearrange_out(out: torch.Tensor, out_pattern="b n s d", required_out_pattern="b n s d", dims=None):
60
+ dims = {} if dims is None else dims
61
+ if out_pattern != required_out_pattern:
62
+ out = rearrange(out, f"{required_out_pattern} -> {out_pattern}", **dims)
63
+ return out
64
+
65
+
66
+ def torch_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None):
67
+ required_in_pattern, required_out_pattern= "b n s d", "b n s d"
68
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
69
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask, scale=scale)
70
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
71
+ return out
72
+
73
+
74
+ def flash_attention_3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
75
+ required_in_pattern, required_out_pattern= "b s n d", "b s n d"
76
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
77
+ out = flash_attn_interface.flash_attn_func(q, k, v, softmax_scale=scale)
78
+ if isinstance(out, tuple):
79
+ out = out[0]
80
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
81
+ return out
82
+
83
+
84
+ def flash_attention_2(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
85
+ required_in_pattern, required_out_pattern= "b s n d", "b s n d"
86
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
87
+ out = flash_attn.flash_attn_func(q, k, v, softmax_scale=scale)
88
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
89
+ return out
90
+
91
+
92
+ def sage_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
93
+ required_in_pattern, required_out_pattern= "b n s d", "b n s d"
94
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
95
+ out = sageattn(q, k, v, sm_scale=scale)
96
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
97
+ return out
98
+
99
+
100
+ def xformers_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, scale=None):
101
+ required_in_pattern, required_out_pattern= "b s n d", "b s n d"
102
+ q, k, v = rearrange_qkv(q, k, v, q_pattern, k_pattern, v_pattern, required_in_pattern, dims)
103
+ out = xops.memory_efficient_attention(q, k, v, scale=scale)
104
+ out = rearrange_out(out, out_pattern, required_out_pattern, dims)
105
+ return out
106
+
107
+
108
+ def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="b n s d", k_pattern="b n s d", v_pattern="b n s d", out_pattern="b n s d", dims=None, attn_mask=None, scale=None, compatibility_mode=False):
109
+ if compatibility_mode or (attn_mask is not None):
110
+ return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, attn_mask=attn_mask, scale=scale)
111
+ else:
112
+ if ATTENTION_IMPLEMENTATION == "flash_attention_3":
113
+ return flash_attention_3(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
114
+ elif ATTENTION_IMPLEMENTATION == "flash_attention_2":
115
+ return flash_attention_2(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
116
+ elif ATTENTION_IMPLEMENTATION == "sage_attention":
117
+ return sage_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
118
+ elif ATTENTION_IMPLEMENTATION == "xformers":
119
+ return xformers_attention(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
120
+ else:
121
+ return torch_sdpa(q, k, v, q_pattern, k_pattern, v_pattern, out_pattern, dims, scale=scale)
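
`attention_forward` dispatches to the fastest backend detected at import time (FlashAttention-3/2, SageAttention, xFormers, then PyTorch SDPA), with the `DIFFSYNTH_ATTENTION_IMPLEMENTATION` environment variable taking priority; it always falls back to SDPA when an `attn_mask` is given or `compatibility_mode=True`. A small usage sketch:

```python
# Usage sketch for attention_forward. Shapes follow the einops patterns in the
# signatures above: "b n s d" = (batch, heads, sequence, head_dim).
import os
# Optional override; must be set before the module is imported because the
# backend is chosen at import time. Supported values mirror
# initialize_attention_priority(): "flash_attention_3", "flash_attention_2",
# "sage_attention", "xformers", "torch".
os.environ["DIFFSYNTH_ATTENTION_IMPLEMENTATION"] = "torch"

import torch
from diffsynth.core.attention import attention_forward

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)
out = attention_forward(q, k, v)            # (1, 8, 1024, 64)

# Inputs stored as (batch, seq, heads, dim) can be passed without manual transposes:
q2 = q.transpose(1, 2)                      # (1, 1024, 8, 64)
out2 = attention_forward(q2, k, v, q_pattern="b s n d")
```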
diffsynth/core/data/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .unified_dataset import UnifiedDataset
diffsynth/core/data/operators.py ADDED
@@ -0,0 +1,218 @@
1
+ import torch, torchvision, imageio, os
2
+ import imageio.v3 as iio
3
+ from PIL import Image
4
+
5
+
6
+ class DataProcessingPipeline:
7
+ def __init__(self, operators=None):
8
+ self.operators: list[DataProcessingOperator] = [] if operators is None else operators
9
+
10
+ def __call__(self, data):
11
+ for operator in self.operators:
12
+ data = operator(data)
13
+ return data
14
+
15
+ def __rshift__(self, pipe):
16
+ if isinstance(pipe, DataProcessingOperator):
17
+ pipe = DataProcessingPipeline([pipe])
18
+ return DataProcessingPipeline(self.operators + pipe.operators)
19
+
20
+
21
+ class DataProcessingOperator:
22
+ def __call__(self, data):
23
+ raise NotImplementedError("DataProcessingOperator cannot be called directly.")
24
+
25
+ def __rshift__(self, pipe):
26
+ if isinstance(pipe, DataProcessingOperator):
27
+ pipe = DataProcessingPipeline([pipe])
28
+ return DataProcessingPipeline([self]).__rshift__(pipe)
29
+
30
+
31
+ class DataProcessingOperatorRaw(DataProcessingOperator):
32
+ def __call__(self, data):
33
+ return data
34
+
35
+
36
+ class ToInt(DataProcessingOperator):
37
+ def __call__(self, data):
38
+ return int(data)
39
+
40
+
41
+ class ToFloat(DataProcessingOperator):
42
+ def __call__(self, data):
43
+ return float(data)
44
+
45
+
46
+ class ToStr(DataProcessingOperator):
47
+ def __init__(self, none_value=""):
48
+ self.none_value = none_value
49
+
50
+ def __call__(self, data):
51
+ if data is None: data = self.none_value
52
+ return str(data)
53
+
54
+
55
+ class LoadImage(DataProcessingOperator):
56
+ def __init__(self, convert_RGB=True):
57
+ self.convert_RGB = convert_RGB
58
+
59
+ def __call__(self, data: str):
60
+ image = Image.open(data)
61
+ if self.convert_RGB: image = image.convert("RGB")
62
+ return image
63
+
64
+
65
+ class ImageCropAndResize(DataProcessingOperator):
66
+ def __init__(self, height=None, width=None, max_pixels=None, height_division_factor=1, width_division_factor=1):
67
+ self.height = height
68
+ self.width = width
69
+ self.max_pixels = max_pixels
70
+ self.height_division_factor = height_division_factor
71
+ self.width_division_factor = width_division_factor
72
+
73
+ def crop_and_resize(self, image, target_height, target_width):
74
+ width, height = image.size
75
+ scale = max(target_width / width, target_height / height)
76
+ image = torchvision.transforms.functional.resize(
77
+ image,
78
+ (round(height*scale), round(width*scale)),
79
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR
80
+ )
81
+ image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
82
+ return image
83
+
84
+ def get_height_width(self, image):
85
+ if self.height is None or self.width is None:
86
+ width, height = image.size
87
+ if width * height > self.max_pixels:
88
+ scale = (width * height / self.max_pixels) ** 0.5
89
+ height, width = int(height / scale), int(width / scale)
90
+ height = height // self.height_division_factor * self.height_division_factor
91
+ width = width // self.width_division_factor * self.width_division_factor
92
+ else:
93
+ height, width = self.height, self.width
94
+ return height, width
95
+
96
+ def __call__(self, data: Image.Image):
97
+ image = self.crop_and_resize(data, *self.get_height_width(data))
98
+ return image
99
+
100
+
101
+ class ToList(DataProcessingOperator):
102
+ def __call__(self, data):
103
+ return [data]
104
+
105
+
106
+ class LoadVideo(DataProcessingOperator):
107
+ def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
108
+ self.num_frames = num_frames
109
+ self.time_division_factor = time_division_factor
110
+ self.time_division_remainder = time_division_remainder
111
+ # frame_processor is built into the video loader for efficiency.
112
+ self.frame_processor = frame_processor
113
+
114
+ def get_num_frames(self, reader):
115
+ num_frames = self.num_frames
116
+ if int(reader.count_frames()) < num_frames:
117
+ num_frames = int(reader.count_frames())
118
+ while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
119
+ num_frames -= 1
120
+ return num_frames
121
+
122
+ def __call__(self, data: str):
123
+ reader = imageio.get_reader(data)
124
+ num_frames = self.get_num_frames(reader)
125
+ frames = []
126
+ for frame_id in range(num_frames):
127
+ frame = reader.get_data(frame_id)
128
+ frame = Image.fromarray(frame)
129
+ frame = self.frame_processor(frame)
130
+ frames.append(frame)
131
+ reader.close()
132
+ return frames
133
+
134
+
135
+ class SequencialProcess(DataProcessingOperator):
136
+ def __init__(self, operator=lambda x: x):
137
+ self.operator = operator
138
+
139
+ def __call__(self, data):
140
+ return [self.operator(i) for i in data]
141
+
142
+
143
+ class LoadGIF(DataProcessingOperator):
144
+ def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
145
+ self.num_frames = num_frames
146
+ self.time_division_factor = time_division_factor
147
+ self.time_division_remainder = time_division_remainder
148
+ # frame_processor is built into the GIF loader for efficiency.
149
+ self.frame_processor = frame_processor
150
+
151
+ def get_num_frames(self, path):
152
+ num_frames = self.num_frames
153
+ images = iio.imread(path, mode="RGB")
154
+ if len(images) < num_frames:
155
+ num_frames = len(images)
156
+ while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
157
+ num_frames -= 1
158
+ return num_frames
159
+
160
+ def __call__(self, data: str):
161
+ num_frames = self.get_num_frames(data)
162
+ frames = []
163
+ images = iio.imread(data, mode="RGB")
164
+ for img in images:
165
+ frame = Image.fromarray(img)
166
+ frame = self.frame_processor(frame)
167
+ frames.append(frame)
168
+ if len(frames) >= num_frames:
169
+ break
170
+ return frames
171
+
172
+
173
+ class RouteByExtensionName(DataProcessingOperator):
174
+ def __init__(self, operator_map):
175
+ self.operator_map = operator_map
176
+
177
+ def __call__(self, data: str):
178
+ file_ext_name = data.split(".")[-1].lower()
179
+ for ext_names, operator in self.operator_map:
180
+ if ext_names is None or file_ext_name in ext_names:
181
+ return operator(data)
182
+ raise ValueError(f"Unsupported file: {data}")
183
+
184
+
185
+ class RouteByType(DataProcessingOperator):
186
+ def __init__(self, operator_map):
187
+ self.operator_map = operator_map
188
+
189
+ def __call__(self, data):
190
+ for dtype, operator in self.operator_map:
191
+ if dtype is None or isinstance(data, dtype):
192
+ return operator(data)
193
+ raise ValueError(f"Unsupported data: {data}")
194
+
195
+
196
+ class LoadTorchPickle(DataProcessingOperator):
197
+ def __init__(self, map_location="cpu"):
198
+ self.map_location = map_location
199
+
200
+ def __call__(self, data):
201
+ return torch.load(data, map_location=self.map_location, weights_only=False)
202
+
203
+
204
+ class ToAbsolutePath(DataProcessingOperator):
205
+ def __init__(self, base_path=""):
206
+ self.base_path = base_path
207
+
208
+ def __call__(self, data):
209
+ return os.path.join(self.base_path, data)
210
+
211
+
212
+ class LoadAudio(DataProcessingOperator):
213
+ def __init__(self, sr=16000):
214
+ self.sr = sr
215
+ def __call__(self, data: str):
216
+ import librosa
217
+ input_audio, sample_rate = librosa.load(data, sr=self.sr)
218
+ return input_audio
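
Operators compose with `>>` (via `DataProcessingOperator.__rshift__`) into a `DataProcessingPipeline`, so a preprocessing chain reads left to right. A minimal sketch; the directory and file name are placeholders:

```python
# Composing operators into a pipeline. "data/train" and "cat.jpg" are hypothetical.
from diffsynth.core.data.operators import (
    ToAbsolutePath, LoadImage, ImageCropAndResize, ToList,
)

pipeline = (
    ToAbsolutePath("data/train")                  # "cat.jpg" -> "data/train/cat.jpg"
    >> LoadImage()                                # path -> PIL.Image (RGB)
    >> ImageCropAndResize(height=512, width=512)  # center-crop + resize
    >> ToList()                                   # single image -> [image]
)
frames = pipeline("cat.jpg")
```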
diffsynth/core/data/unified_dataset.py ADDED
@@ -0,0 +1,112 @@
1
+ from .operators import *
2
+ import torch, json, pandas
3
+
4
+
5
+ class UnifiedDataset(torch.utils.data.Dataset):
6
+ def __init__(
7
+ self,
8
+ base_path=None, metadata_path=None,
9
+ repeat=1,
10
+ data_file_keys=tuple(),
11
+ main_data_operator=lambda x: x,
12
+ special_operator_map=None,
13
+ ):
14
+ self.base_path = base_path
15
+ self.metadata_path = metadata_path
16
+ self.repeat = repeat
17
+ self.data_file_keys = data_file_keys
18
+ self.main_data_operator = main_data_operator
19
+ self.cached_data_operator = LoadTorchPickle()
20
+ self.special_operator_map = {} if special_operator_map is None else special_operator_map
21
+ self.data = []
22
+ self.cached_data = []
23
+ self.load_from_cache = metadata_path is None
24
+ self.load_metadata(metadata_path)
25
+
26
+ @staticmethod
27
+ def default_image_operator(
28
+ base_path="",
29
+ max_pixels=1920*1080, height=None, width=None,
30
+ height_division_factor=16, width_division_factor=16,
31
+ ):
32
+ return RouteByType(operator_map=[
33
+ (str, ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)),
34
+ (list, SequencialProcess(ToAbsolutePath(base_path) >> LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor))),
35
+ ])
36
+
37
+ @staticmethod
38
+ def default_video_operator(
39
+ base_path="",
40
+ max_pixels=1920*1080, height=None, width=None,
41
+ height_division_factor=16, width_division_factor=16,
42
+ num_frames=81, time_division_factor=4, time_division_remainder=1,
43
+ ):
44
+ return RouteByType(operator_map=[
45
+ (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
46
+ (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
47
+ (("gif",), LoadGIF(
48
+ num_frames, time_division_factor, time_division_remainder,
49
+ frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
50
+ )),
51
+ (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
52
+ num_frames, time_division_factor, time_division_remainder,
53
+ frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
54
+ )),
55
+ ])),
56
+ ])
57
+
58
+ def search_for_cached_data_files(self, path):
59
+ for file_name in os.listdir(path):
60
+ subpath = os.path.join(path, file_name)
61
+ if os.path.isdir(subpath):
62
+ self.search_for_cached_data_files(subpath)
63
+ elif subpath.endswith(".pth"):
64
+ self.cached_data.append(subpath)
65
+
66
+ def load_metadata(self, metadata_path):
67
+ if metadata_path is None:
68
+ print("No metadata_path. Searching for cached data files.")
69
+ self.search_for_cached_data_files(self.base_path)
70
+ print(f"{len(self.cached_data)} cached data files found.")
71
+ elif metadata_path.endswith(".json"):
72
+ with open(metadata_path, "r") as f:
73
+ metadata = json.load(f)
74
+ self.data = metadata
75
+ elif metadata_path.endswith(".jsonl"):
76
+ metadata = []
77
+ with open(metadata_path, 'r') as f:
78
+ for line in f:
79
+ metadata.append(json.loads(line.strip()))
80
+ self.data = metadata
81
+ else:
82
+ metadata = pandas.read_csv(metadata_path)
83
+ self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]
84
+
85
+ def __getitem__(self, data_id):
86
+ if self.load_from_cache:
87
+ data = self.cached_data[data_id % len(self.cached_data)]
88
+ data = self.cached_data_operator(data)
89
+ else:
90
+ data = self.data[data_id % len(self.data)].copy()
91
+ for key in self.data_file_keys:
92
+ if key in data:
93
+ if key in self.special_operator_map:
94
+ data[key] = self.special_operator_map[key](data[key])
95
+ elif key in self.data_file_keys:
96
+ data[key] = self.main_data_operator(data[key])
97
+ return data
98
+
99
+ def __len__(self):
100
+ if self.load_from_cache:
101
+ return len(self.cached_data) * self.repeat
102
+ else:
103
+ return len(self.data) * self.repeat
104
+
105
+ def check_data_equal(self, data1, data2):
106
+ # Debug only
107
+ if len(data1) != len(data2):
108
+ return False
109
+ for k in data1:
110
+ if data1[k] != data2[k]:
111
+ return False
112
+ return True
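
`UnifiedDataset` either reads a metadata table (`.json`, `.jsonl`, or CSV) and pushes the columns named in `data_file_keys` through the operator chain, or, when `metadata_path` is omitted, scans `base_path` for cached `.pth` files. A hedged sketch of the metadata-driven path; the paths and column names below are hypothetical:

```python
# Hypothetical metadata-driven dataset: only the "image" column is run through
# the operator chain, other columns (e.g. "prompt") are returned as-is.
import torch
from diffsynth.core.data.unified_dataset import UnifiedDataset

dataset = UnifiedDataset(
    base_path="data/train",
    metadata_path="data/train/metadata.csv",   # also accepts .json / .jsonl
    data_file_keys=("image",),
    main_data_operator=UnifiedDataset.default_image_operator(
        base_path="data/train", height=1024, width=1024,
    ),
)
sample = dataset[0]   # e.g. {"image": <PIL.Image>, "prompt": "..."}
loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0])
```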
diffsynth/core/gradient/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .gradient_checkpoint import gradient_checkpoint_forward
diffsynth/core/gradient/gradient_checkpoint.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+
3
+
4
+ def create_custom_forward(module):
5
+ def custom_forward(*inputs, **kwargs):
6
+ return module(*inputs, **kwargs)
7
+ return custom_forward
8
+
9
+
10
+ def gradient_checkpoint_forward(
11
+ model,
12
+ use_gradient_checkpointing,
13
+ use_gradient_checkpointing_offload,
14
+ *args,
15
+ **kwargs,
16
+ ):
17
+ if use_gradient_checkpointing_offload:
18
+ with torch.autograd.graph.save_on_cpu():
19
+ model_output = torch.utils.checkpoint.checkpoint(
20
+ create_custom_forward(model),
21
+ *args,
22
+ **kwargs,
23
+ use_reentrant=False,
24
+ )
25
+ elif use_gradient_checkpointing:
26
+ model_output = torch.utils.checkpoint.checkpoint(
27
+ create_custom_forward(model),
28
+ *args,
29
+ **kwargs,
30
+ use_reentrant=False,
31
+ )
32
+ else:
33
+ model_output = model(*args, **kwargs)
34
+ return model_output
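
`gradient_checkpoint_forward` wraps a single module call and either runs it directly, recomputes activations during the backward pass, or additionally parks saved tensors on the CPU. A sketch of how a training loop over transformer blocks might call it; `blocks`, `hidden_states`, and `temb` are placeholders for whatever the surrounding model defines:

```python
# Sketch: routing block calls through gradient_checkpoint_forward during training.
from diffsynth.core.gradient import gradient_checkpoint_forward

def run_blocks(blocks, hidden_states, temb, training=True):
    for block in blocks:
        hidden_states = gradient_checkpoint_forward(
            block,          # the module to run
            training,       # use_gradient_checkpointing: recompute activations in backward
            False,          # use_gradient_checkpointing_offload: True also saves them on CPU
            hidden_states,  # positional args forwarded to block(...)
            temb=temb,      # keyword args forwarded to block(...)
        )
    return hidden_states
```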
diffsynth/core/loader/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .file import load_state_dict, hash_state_dict_keys, hash_model_file
2
+ from .model import load_model, load_model_with_disk_offload
3
+ from .config import ModelConfig
diffsynth/core/loader/config.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch, glob, os
2
+ from typing import Optional, Union
3
+ from dataclasses import dataclass
4
+ from modelscope import snapshot_download
5
+ from huggingface_hub import snapshot_download as hf_snapshot_download
6
+ from typing import Optional
7
+
8
+
9
+ @dataclass
10
+ class ModelConfig:
11
+ path: Union[str, list[str]] = None
12
+ model_id: str = None
13
+ origin_file_pattern: Union[str, list[str]] = None
14
+ download_source: str = None
15
+ local_model_path: str = None
16
+ skip_download: bool = None
17
+ offload_device: Optional[Union[str, torch.device]] = None
18
+ offload_dtype: Optional[torch.dtype] = None
19
+ onload_device: Optional[Union[str, torch.device]] = None
20
+ onload_dtype: Optional[torch.dtype] = None
21
+ preparing_device: Optional[Union[str, torch.device]] = None
22
+ preparing_dtype: Optional[torch.dtype] = None
23
+ computation_device: Optional[Union[str, torch.device]] = None
24
+ computation_dtype: Optional[torch.dtype] = None
25
+ clear_parameters: bool = False
26
+
27
+ def check_input(self):
28
+ if self.path is None and self.model_id is None:
29
+ raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
30
+
31
+ def parse_original_file_pattern(self):
32
+ if self.origin_file_pattern is None or self.origin_file_pattern == "":
33
+ return "*"
34
+ elif self.origin_file_pattern.endswith("/"):
35
+ return self.origin_file_pattern + "*"
36
+ else:
37
+ return self.origin_file_pattern
38
+
39
+ def parse_download_source(self):
40
+ if self.download_source is None:
41
+ if os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE') is not None:
42
+ return os.environ.get('DIFFSYNTH_DOWNLOAD_SOURCE')
43
+ else:
44
+ return "modelscope"
45
+ else:
46
+ return self.download_source
47
+
48
+ def parse_skip_download(self):
49
+ if self.skip_download is None:
50
+ if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD') is not None:
51
+ if os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "true":
52
+ return True
53
+ elif os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD').lower() == "false":
54
+ return False
55
+ else:
56
+ return False
57
+ else:
58
+ return self.skip_download
59
+
60
+ def download(self):
61
+ origin_file_pattern = self.parse_original_file_pattern()
62
+ downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
63
+ download_source = self.parse_download_source()
64
+ if download_source.lower() == "modelscope":
65
+ snapshot_download(
66
+ self.model_id,
67
+ local_dir=os.path.join(self.local_model_path, self.model_id),
68
+ allow_file_pattern=origin_file_pattern,
69
+ ignore_file_pattern=downloaded_files,
70
+ local_files_only=False
71
+ )
72
+ elif download_source.lower() == "huggingface":
73
+ hf_snapshot_download(
74
+ self.model_id,
75
+ local_dir=os.path.join(self.local_model_path, self.model_id),
76
+ allow_patterns=origin_file_pattern,
77
+ ignore_patterns=downloaded_files,
78
+ local_files_only=False
79
+ )
80
+ else:
81
+ raise ValueError("`download_source` should be `modelscope` or `huggingface`.")
82
+
83
+ def require_downloading(self):
84
+ if self.path is not None:
85
+ return False
86
+ skip_download = self.parse_skip_download()
87
+ return not skip_download
88
+
89
+ def reset_local_model_path(self):
90
+ if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
91
+ self.local_model_path = os.environ.get('DIFFSYNTH_MODEL_BASE_PATH')
92
+ elif self.local_model_path is None:
93
+ self.local_model_path = "./models"
94
+
95
+ def download_if_necessary(self):
96
+ self.check_input()
97
+ self.reset_local_model_path()
98
+ if self.require_downloading():
99
+ self.download()
100
+ if self.origin_file_pattern is None or self.origin_file_pattern == "":
101
+ self.path = os.path.join(self.local_model_path, self.model_id)
102
+ else:
103
+ self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
104
+ if isinstance(self.path, list) and len(self.path) == 1:
105
+ self.path = self.path[0]
106
+
107
+ def vram_config(self):
108
+ return {
109
+ "offload_device": self.offload_device,
110
+ "offload_dtype": self.offload_dtype,
111
+ "onload_device": self.onload_device,
112
+ "onload_dtype": self.onload_dtype,
113
+ "preparing_device": self.preparing_device,
114
+ "preparing_dtype": self.preparing_dtype,
115
+ "computation_device": self.computation_device,
116
+ "computation_dtype": self.computation_dtype,
117
+ }
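
A `ModelConfig` describes where a checkpoint comes from (a local `path`, or a `model_id` plus `origin_file_pattern`) and which device/dtype each VRAM state should use; `download_if_necessary()` resolves it to concrete files, honoring `DIFFSYNTH_DOWNLOAD_SOURCE`, `DIFFSYNTH_SKIP_DOWNLOAD`, and `DIFFSYNTH_MODEL_BASE_PATH`. A sketch using the model id quoted in the config comments earlier in this diff:

```python
# Resolving a ModelConfig to local files; the device/dtype fields are optional
# and only matter when VRAM management is enabled downstream.
import torch
from diffsynth.core.loader import ModelConfig

config = ModelConfig(
    model_id="Tongyi-MAI/Z-Image-Turbo",
    origin_file_pattern="vae/vae/diffusion_pytorch_model.safetensors",
    offload_device="cpu", offload_dtype=torch.bfloat16,
    computation_device="cuda", computation_dtype=torch.bfloat16,
)
config.download_if_necessary()   # downloads under ./models (or DIFFSYNTH_MODEL_BASE_PATH)
print(config.path)               # resolved file path(s)
print(config.vram_config())      # dict consumed by the VRAM-management layers
```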
diffsynth/core/loader/file.py ADDED
@@ -0,0 +1,121 @@
1
+ from safetensors import safe_open
2
+ import torch, hashlib
3
+
4
+
5
+ def load_state_dict(file_path, torch_dtype=None, device="cpu"):
6
+ if isinstance(file_path, list):
7
+ state_dict = {}
8
+ for file_path_ in file_path:
9
+ state_dict.update(load_state_dict(file_path_, torch_dtype, device))
10
+ return state_dict
11
+ if file_path.endswith(".safetensors"):
12
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
13
+ else:
14
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
15
+
16
+
17
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
18
+ state_dict = {}
19
+ with safe_open(file_path, framework="pt", device=str(device)) as f:
20
+ for k in f.keys():
21
+ state_dict[k] = f.get_tensor(k)
22
+ if torch_dtype is not None:
23
+ state_dict[k] = state_dict[k].to(torch_dtype)
24
+ return state_dict
25
+
26
+
27
+ def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
28
+ state_dict = torch.load(file_path, map_location=device, weights_only=True)
29
+ if len(state_dict) == 1:
30
+ if "state_dict" in state_dict:
31
+ state_dict = state_dict["state_dict"]
32
+ elif "module" in state_dict:
33
+ state_dict = state_dict["module"]
34
+ elif "model_state" in state_dict:
35
+ state_dict = state_dict["model_state"]
36
+ if torch_dtype is not None:
37
+ for i in state_dict:
38
+ if isinstance(state_dict[i], torch.Tensor):
39
+ state_dict[i] = state_dict[i].to(torch_dtype)
40
+ return state_dict
41
+
42
+
43
+ def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
44
+ keys = []
45
+ for key, value in state_dict.items():
46
+ if isinstance(key, str):
47
+ if isinstance(value, torch.Tensor):
48
+ if with_shape:
49
+ shape = "_".join(map(str, list(value.shape)))
50
+ keys.append(key + ":" + shape)
51
+ keys.append(key)
52
+ elif isinstance(value, dict):
53
+ keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
54
+ keys.sort()
55
+ keys_str = ",".join(keys)
56
+ return keys_str
57
+
58
+
59
+ def hash_state_dict_keys(state_dict, with_shape=True):
60
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
61
+ keys_str = keys_str.encode(encoding="UTF-8")
62
+ return hashlib.md5(keys_str).hexdigest()
63
+
64
+
65
+ def load_keys_dict(file_path):
66
+ if isinstance(file_path, list):
67
+ state_dict = {}
68
+ for file_path_ in file_path:
69
+ state_dict.update(load_keys_dict(file_path_))
70
+ return state_dict
71
+ if file_path.endswith(".safetensors"):
72
+ return load_keys_dict_from_safetensors(file_path)
73
+ else:
74
+ return load_keys_dict_from_bin(file_path)
75
+
76
+
77
+ def load_keys_dict_from_safetensors(file_path):
78
+ keys_dict = {}
79
+ with safe_open(file_path, framework="pt", device="cpu") as f:
80
+ for k in f.keys():
81
+ keys_dict[k] = f.get_slice(k).get_shape()
82
+ return keys_dict
83
+
84
+
85
+ def convert_state_dict_to_keys_dict(state_dict):
86
+ keys_dict = {}
87
+ for k, v in state_dict.items():
88
+ if isinstance(v, torch.Tensor):
89
+ keys_dict[k] = list(v.shape)
90
+ else:
91
+ keys_dict[k] = convert_state_dict_to_keys_dict(v)
92
+ return keys_dict
93
+
94
+
95
+ def load_keys_dict_from_bin(file_path):
96
+ state_dict = load_state_dict_from_bin(file_path)
97
+ keys_dict = convert_state_dict_to_keys_dict(state_dict)
98
+ return keys_dict
99
+
100
+
101
+ def convert_keys_dict_to_single_str(state_dict, with_shape=True):
102
+ keys = []
103
+ for key, value in state_dict.items():
104
+ if isinstance(key, str):
105
+ if isinstance(value, dict):
106
+ keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape))
107
+ else:
108
+ if with_shape:
109
+ shape = "_".join(map(str, list(value)))
110
+ keys.append(key + ":" + shape)
111
+ keys.append(key)
112
+ keys.sort()
113
+ keys_str = ",".join(keys)
114
+ return keys_str
115
+
116
+
117
+ def hash_model_file(path, with_shape=True):
118
+ keys_dict = load_keys_dict(path)
119
+ keys_str = convert_keys_dict_to_single_str(keys_dict, with_shape=with_shape)
120
+ keys_str = keys_str.encode(encoding="UTF-8")
121
+ return hashlib.md5(keys_str).hexdigest()
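
The hashes stored in `MODEL_CONFIGS` are MD5 digests of a checkpoint's sorted parameter names (optionally with shapes). For `.safetensors` files, `hash_model_file` reads only the header via `get_slice`, so no tensors are materialized. A short sketch; the path is a placeholder:

```python
# Fingerprinting a checkpoint by its key/shape signature.
from diffsynth.core.loader import hash_model_file, load_state_dict, hash_state_dict_keys

path = "models/some_model/diffusion_pytorch_model.safetensors"  # placeholder
print(hash_model_file(path))            # md5 hex digest of the sorted key/shape string

# Equivalent for a flat tensor dict, but loads the tensors into memory first:
state_dict = load_state_dict(path, device="cpu")
print(hash_state_dict_keys(state_dict)) # should produce the same digest
```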
diffsynth/core/loader/model.py ADDED
@@ -0,0 +1,79 @@
1
+ from ..vram.initialization import skip_model_initialization
2
+ from ..vram.disk_map import DiskMap
3
+ from ..vram.layers import enable_vram_management
4
+ from .file import load_state_dict
5
+ import torch
6
+
7
+
8
+ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
9
+ config = {} if config is None else config
10
+ # Why do we use `skip_model_initialization`?
11
+ # It skips the random initialization of model parameters,
12
+ # thereby speeding up model loading and avoiding excessive memory usage.
13
+ with skip_model_initialization():
14
+ model = model_class(**config)
15
+ # What is `module_map`?
16
+ # This is a module mapping table for VRAM management.
17
+ if module_map is not None:
18
+ devices = [vram_config["offload_device"], vram_config["onload_device"], vram_config["preparing_device"], vram_config["computation_device"]]
19
+ device = [d for d in devices if d != "disk"][0]
20
+ dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
21
+ dtype = [d for d in dtypes if d != "disk"][0]
22
+ if vram_config["offload_device"] != "disk":
23
+ state_dict = DiskMap(path, device, torch_dtype=dtype)
24
+ if state_dict_converter is not None:
25
+ state_dict = state_dict_converter(state_dict)
26
+ else:
27
+ state_dict = {i: state_dict[i] for i in state_dict}
28
+ model.load_state_dict(state_dict, assign=True)
29
+ model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=None, vram_limit=vram_limit)
30
+ else:
31
+ disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
32
+ model = enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=vram_limit)
33
+ else:
34
+ # Why do we use `DiskMap`?
35
+ # Sometimes a model file contains multiple models,
36
+ # and DiskMap can load only the parameters of a single model,
37
+ # avoiding the need to load all parameters in the file.
38
+ if use_disk_map:
39
+ state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
40
+ else:
41
+ state_dict = load_state_dict(path, torch_dtype, device)
42
+ # Why do we use `state_dict_converter`?
43
+ # Some models are saved in complex formats,
44
+ # and we need to convert the state dict into the appropriate format.
45
+ if state_dict_converter is not None:
46
+ state_dict = state_dict_converter(state_dict)
47
+ else:
48
+ state_dict = {i: state_dict[i] for i in state_dict}
49
+ model.load_state_dict(state_dict, assign=True)
50
+ # Why do we call `to()`?
51
+ # Because some models override the behavior of `to()`,
52
+ # especially those from libraries like Transformers.
53
+ model = model.to(dtype=torch_dtype, device=device)
54
+ if hasattr(model, "eval"):
55
+ model = model.eval()
56
+ return model
57
+
58
+
59
+ def load_model_with_disk_offload(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, module_map=None):
60
+ if isinstance(path, str):
61
+ path = [path]
62
+ config = {} if config is None else config
63
+ with skip_model_initialization():
64
+ model = model_class(**config)
65
+ if hasattr(model, "eval"):
66
+ model = model.eval()
67
+ disk_map = DiskMap(path, device, state_dict_converter=state_dict_converter)
68
+ vram_config = {
69
+ "offload_dtype": "disk",
70
+ "offload_device": "disk",
71
+ "onload_dtype": "disk",
72
+ "onload_device": "disk",
73
+ "preparing_dtype": torch.float8_e4m3fn,
74
+ "preparing_device": device,
75
+ "computation_dtype": torch_dtype,
76
+ "computation_device": device,
77
+ }
78
+ enable_vram_management(model, module_map, vram_config=vram_config, disk_map=disk_map, vram_limit=80)
79
+ return model
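
`load_model` builds the module on the meta device, loads weights (optionally streamed through `DiskMap` and rewritten by a `state_dict_converter`), and wires up VRAM management when a `module_map` is supplied. A minimal sketch of the simple path with a placeholder checkpoint path; real `MODEL_CONFIGS` entries may also supply a `state_dict_converter` and constructor kwargs:

```python
# Minimal sketch of the simple path through load_model (no VRAM management).
import torch
from diffsynth.core.loader import load_model
from diffsynth.models.flux_vae import FluxVAEDecoder  # class named in MODEL_CONFIGS above

vae_decoder = load_model(
    FluxVAEDecoder,
    path="models/flux_vae/diffusion_pytorch_model.safetensors",  # placeholder path
    torch_dtype=torch.bfloat16,
    device="cuda",
    use_disk_map=True,   # stream tensors from disk instead of loading the whole file at once
)
```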
diffsynth/core/vram/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .initialization import skip_model_initialization
2
+ from .layers import *
diffsynth/core/vram/disk_map.py ADDED
@@ -0,0 +1,93 @@
1
+ from safetensors import safe_open
2
+ import torch, os
3
+
4
+
5
+ class SafetensorsCompatibleTensor:
6
+ def __init__(self, tensor):
7
+ self.tensor = tensor
8
+
9
+ def get_shape(self):
10
+ return list(self.tensor.shape)
11
+
12
+
13
+ class SafetensorsCompatibleBinaryLoader:
14
+ def __init__(self, path, device):
15
+ print("Detected a non-safetensors file, which may slow down loading. It is recommended to convert it to the safetensors format.")
16
+ self.state_dict = torch.load(path, weights_only=True, map_location=device)
17
+
18
+ def keys(self):
19
+ return self.state_dict.keys()
20
+
21
+ def get_tensor(self, name):
22
+ return self.state_dict[name]
23
+
24
+ def get_slice(self, name):
25
+ return SafetensorsCompatibleTensor(self.state_dict[name])
26
+
27
+
28
+ class DiskMap:
29
+
30
+ def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
31
+ self.path = path if isinstance(path, list) else [path]
32
+ self.device = device
33
+ self.torch_dtype = torch_dtype
34
+ if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None:
35
+ self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))
36
+ else:
37
+ self.buffer_size = buffer_size
38
+ self.files = []
39
+ self.flush_files()
40
+ self.name_map = {}
41
+ for file_id, file in enumerate(self.files):
42
+ for name in file.keys():
43
+ self.name_map[name] = file_id
44
+ self.rename_dict = self.fetch_rename_dict(state_dict_converter)
45
+
46
+ def flush_files(self):
47
+ if len(self.files) == 0:
48
+ for path in self.path:
49
+ if path.endswith(".safetensors"):
50
+ self.files.append(safe_open(path, framework="pt", device=str(self.device)))
51
+ else:
52
+ self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
53
+ else:
54
+ for i, path in enumerate(self.path):
55
+ if path.endswith(".safetensors"):
56
+ self.files[i] = safe_open(path, framework="pt", device=str(self.device))
57
+ self.num_params = 0
58
+
59
+ def __getitem__(self, name):
60
+ if self.rename_dict is not None: name = self.rename_dict[name]
61
+ file_id = self.name_map[name]
62
+ param = self.files[file_id].get_tensor(name)
63
+ if self.torch_dtype is not None and isinstance(param, torch.Tensor):
64
+ param = param.to(self.torch_dtype)
65
+ if isinstance(param, torch.Tensor) and param.device == "cpu":
66
+ param = param.clone()
67
+ if isinstance(param, torch.Tensor):
68
+ self.num_params += param.numel()
69
+ if self.num_params > self.buffer_size:
70
+ self.flush_files()
71
+ return param
72
+
73
+ def fetch_rename_dict(self, state_dict_converter):
74
+ if state_dict_converter is None:
75
+ return None
76
+ state_dict = {}
77
+ for file in self.files:
78
+ for name in file.keys():
79
+ state_dict[name] = name
80
+ state_dict = state_dict_converter(state_dict)
81
+ return state_dict
82
+
83
+ def __iter__(self):
84
+ if self.rename_dict is not None:
85
+ return self.rename_dict.__iter__()
86
+ else:
87
+ return self.name_map.__iter__()
88
+
89
+ def __contains__(self, x):
90
+ if self.rename_dict is not None:
91
+ return x in self.rename_dict
92
+ else:
93
+ return x in self.name_map
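
`DiskMap` behaves like a read-only state dict backed by the checkpoint files: tensors are fetched lazily through `safe_open`, optionally renamed by a state-dict converter, and the file handles are re-opened once roughly `buffer_size` parameters have been read (presumably to release memory held by the open handles). A short sketch; the paths and key name are placeholders:

```python
# DiskMap as a lazy, dict-like view over checkpoint files.
import torch
from diffsynth.core.vram.disk_map import DiskMap

weights = DiskMap(
    "models/dit/diffusion_pytorch_model.safetensors",  # placeholder; a list of shards also works
    device="cpu",
    torch_dtype=torch.bfloat16,
    buffer_size=10**9,   # flush file handles after ~1e9 parameters have been read
)
for name in weights:                  # iterates parameter names without loading tensors
    pass
tensor = weights["some.weight.name"]  # loads just this tensor (placeholder key)
```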
diffsynth/core/vram/initialization.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ from contextlib import contextmanager
3
+
4
+
5
+ @contextmanager
6
+ def skip_model_initialization(device=torch.device("meta")):
7
+
8
+ def register_empty_parameter(module, name, param):
9
+ old_register_parameter(module, name, param)
10
+ if param is not None:
11
+ param_cls = type(module._parameters[name])
12
+ kwargs = module._parameters[name].__dict__
13
+ kwargs["requires_grad"] = param.requires_grad
14
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
15
+
16
+ old_register_parameter = torch.nn.Module.register_parameter
17
+ torch.nn.Module.register_parameter = register_empty_parameter
18
+ try:
19
+ yield
20
+ finally:
21
+ torch.nn.Module.register_parameter = old_register_parameter
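
`skip_model_initialization` temporarily patches `torch.nn.Module.register_parameter` so that freshly constructed parameters land on the meta device: no host memory is allocated and no random initialization is kept until real weights are attached, e.g. with `load_state_dict(..., assign=True)` as `load_model` does earlier in this diff. For example:

```python
# Parameters are created on the meta device (no storage, no random init)
# until real weights are assigned.
import torch
from diffsynth.core.vram import skip_model_initialization

with skip_model_initialization():
    layer = torch.nn.Linear(4096, 4096)
print(layer.weight.device)   # meta

# Attach real weights later without an extra copy:
state_dict = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state_dict, assign=True)
print(layer.weight.device)   # cpu
```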
diffsynth/core/vram/layers.py ADDED
@@ -0,0 +1,475 @@
1
+ import torch, copy
2
+ from typing import Union
3
+ from .initialization import skip_model_initialization
4
+ from .disk_map import DiskMap
5
+
6
+
7
+ class AutoTorchModule(torch.nn.Module):
8
+
9
+ def __init__(
10
+ self,
11
+ offload_dtype: torch.dtype = None,
12
+ offload_device: Union[str, torch.device] = None,
13
+ onload_dtype: torch.dtype = None,
14
+ onload_device: Union[str, torch.device] = None,
15
+ preparing_dtype: torch.dtype = None,
16
+ preparing_device: Union[str, torch.device] = None,
17
+ computation_dtype: torch.dtype = None,
18
+ computation_device: Union[str, torch.device] = None,
19
+ vram_limit: float = None,
20
+ ):
21
+ super().__init__()
22
+ self.set_dtype_and_device(
23
+ offload_dtype,
24
+ offload_device,
25
+ onload_dtype,
26
+ onload_device,
27
+ preparing_dtype,
28
+ preparing_device,
29
+ computation_dtype,
30
+ computation_device,
31
+ vram_limit,
32
+ )
33
+ self.state = 0
34
+ self.name = ""
35
+
36
+ def set_dtype_and_device(
37
+ self,
38
+ offload_dtype: torch.dtype = None,
39
+ offload_device: Union[str, torch.device] = None,
40
+ onload_dtype: torch.dtype = None,
41
+ onload_device: Union[str, torch.device] = None,
42
+ preparing_dtype: torch.dtype = None,
43
+ preparing_device: Union[str, torch.device] = None,
44
+ computation_dtype: torch.dtype = None,
45
+ computation_device: Union[str, torch.device] = None,
46
+ vram_limit: float = None,
47
+ ):
48
+ self.offload_dtype = offload_dtype or computation_dtype
49
+ self.offload_device = offload_device or computation_device
50
+ self.onload_dtype = onload_dtype or computation_dtype
51
+ self.onload_device = onload_device or computation_device
52
+ self.preparing_dtype = preparing_dtype or computation_dtype
53
+ self.preparing_device = preparing_device or computation_device
54
+ self.computation_dtype = computation_dtype
55
+ self.computation_device = computation_device
56
+ self.vram_limit = vram_limit
57
+
58
+ def cast_to(self, weight, dtype, device):
59
+ r = torch.empty_like(weight, dtype=dtype, device=device)
60
+ r.copy_(weight)
61
+ return r
62
+
63
+ def check_free_vram(self):
64
+ gpu_mem_state = torch.cuda.mem_get_info(self.computation_device)
65
+ used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
66
+ return used_memory < self.vram_limit
67
+
68
+ def offload(self):
69
+ if self.state != 0:
70
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
71
+ self.state = 0
72
+
73
+ def onload(self):
74
+ if self.state != 1:
75
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
76
+ self.state = 1
77
+
78
+ def param_name(self, name):
79
+ if self.name == "":
80
+ return name
81
+ else:
82
+ return self.name + "." + name
83
+
84
+
85
+ class AutoWrappedModule(AutoTorchModule):
86
+
87
+ def __init__(
88
+ self,
89
+ module: torch.nn.Module,
90
+ offload_dtype: torch.dtype = None,
91
+ offload_device: Union[str, torch.device] = None,
92
+ onload_dtype: torch.dtype = None,
93
+ onload_device: Union[str, torch.device] = None,
94
+ preparing_dtype: torch.dtype = None,
95
+ preparing_device: Union[str, torch.device] = None,
96
+ computation_dtype: torch.dtype = None,
97
+ computation_device: Union[str, torch.device] = None,
98
+ vram_limit: float = None,
99
+ name: str = "",
100
+ disk_map: DiskMap = None,
101
+ **kwargs
102
+ ):
103
+ super().__init__(
104
+ offload_dtype,
105
+ offload_device,
106
+ onload_dtype,
107
+ onload_device,
108
+ preparing_dtype,
109
+ preparing_device,
110
+ computation_dtype,
111
+ computation_device,
112
+ vram_limit,
113
+ )
114
+ self.module = module
115
+ if offload_dtype == "disk":
116
+ self.name = name
117
+ self.disk_map = disk_map
118
+ self.required_params = [name for name, _ in self.module.named_parameters()]
119
+ self.disk_offload = True
120
+ else:
121
+ self.disk_offload = False
122
+
123
+ def load_from_disk(self, torch_dtype, device, copy_module=False):
124
+ if copy_module:
125
+ module = copy.deepcopy(self.module)
126
+ else:
127
+ module = self.module
128
+ state_dict = {}
129
+ for name in self.required_params:
130
+ param = self.disk_map[self.param_name(name)]
131
+ param = param.to(dtype=torch_dtype, device=device)
132
+ state_dict[name] = param
133
+ module.load_state_dict(state_dict, assign=True)
134
+ module.to(dtype=torch_dtype, device=device)
135
+ return module
136
+
137
+ def offload_to_disk(self, model: torch.nn.Module):
138
+ for buf in model.buffers():
139
+ # If there are some parameters are registed in buffers (not in state dict),
140
+ # We cannot offload the model.
141
+ for children in model.children():
142
+ self.offload_to_disk(children)
143
+ break
144
+ else:
145
+ model.to("meta")
146
+
147
+ def offload(self):
148
+ # offload / onload / preparing -> offload
149
+ if self.state != 0:
150
+ if self.disk_offload:
151
+ self.offload_to_disk(self.module)
152
+ else:
153
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
154
+ self.state = 0
155
+
156
+ def onload(self):
157
+ # offload / onload / preparing -> onload
158
+ if self.state < 1:
159
+ if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
160
+ self.load_from_disk(self.onload_dtype, self.onload_device)
161
+ elif self.onload_device != "disk":
162
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
163
+ self.state = 1
164
+
165
+ def preparing(self):
166
+ # onload / preparing -> preparing
167
+ if self.state != 2:
168
+ if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
169
+ self.load_from_disk(self.preparing_dtype, self.preparing_device)
170
+ elif self.preparing_device != "disk":
171
+ self.to(dtype=self.preparing_dtype, device=self.preparing_device)
172
+ self.state = 2
173
+
174
+ def cast_to(self, module, dtype, device):
175
+ return copy.deepcopy(module).to(dtype=dtype, device=device)
176
+
177
+ def computation(self):
178
+ # onload / preparing -> computation (temporary)
179
+ if self.state == 2:
180
+ torch_dtype, device = self.preparing_dtype, self.preparing_device
181
+ else:
182
+ torch_dtype, device = self.onload_dtype, self.onload_device
183
+ if torch_dtype == self.computation_dtype and device == self.computation_device:
184
+ module = self.module
185
+ elif self.disk_offload and device == "disk":
186
+ module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
187
+ else:
188
+ module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
189
+ return module
190
+
191
+ def forward(self, *args, **kwargs):
192
+ if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
193
+ self.preparing()
194
+ module = self.computation()
195
+ return module(*args, **kwargs)
196
+
197
+ def __getattr__(self, name):
198
+ if name in self.__dict__ or name == "module":
199
+ return super().__getattr__(name)
200
+ else:
201
+ return getattr(self.module, name)
202
+
203
+
204
+ class AutoWrappedNonRecurseModule(AutoWrappedModule):
205
+
206
+ def __init__(
207
+ self,
208
+ module: torch.nn.Module,
209
+ offload_dtype: torch.dtype = None,
210
+ offload_device: Union[str, torch.device] = None,
211
+ onload_dtype: torch.dtype = None,
212
+ onload_device: Union[str, torch.device] = None,
213
+ preparing_dtype: torch.dtype = None,
214
+ preparing_device: Union[str, torch.device] = None,
215
+ computation_dtype: torch.dtype = None,
216
+ computation_device: Union[str, torch.device] = None,
217
+ vram_limit: float = None,
218
+ name: str = "",
219
+ disk_map: DiskMap = None,
220
+ **kwargs
221
+ ):
222
+ super().__init__(
223
+ module,
224
+ offload_dtype,
225
+ offload_device,
226
+ onload_dtype,
227
+ onload_device,
228
+ preparing_dtype,
229
+ preparing_device,
230
+ computation_dtype,
231
+ computation_device,
232
+ vram_limit,
233
+ name,
234
+ disk_map,
235
+ **kwargs
236
+ )
237
+ if self.disk_offload:
238
+ self.required_params = [name for name, _ in self.module.named_parameters(recurse=False)]
239
+
240
+ def load_from_disk(self, torch_dtype, device, copy_module=False):
241
+ if copy_module:
242
+ module = copy.deepcopy(self.module)
243
+ else:
244
+ module = self.module
245
+ state_dict = {}
246
+ for name in self.required_params:
247
+ param = self.disk_map[self.param_name(name)]
248
+ param = param.to(dtype=torch_dtype, device=device)
249
+ state_dict[name] = param
250
+ module.load_state_dict(state_dict, assign=True, strict=False)
251
+ return module
252
+
253
+ def offload_to_disk(self, model: torch.nn.Module):
254
+ for name in self.required_params:
255
+ getattr(self, name).to("meta")
256
+
257
+ def cast_to(self, module, dtype, device):
258
+ # Parameter casting is implemented in the model architecture.
259
+ return module
260
+
261
+ def __getattr__(self, name):
262
+ if name in self.__dict__ or name == "module":
263
+ return super().__getattr__(name)
264
+ else:
265
+ return getattr(self.module, name)
266
+
267
+
268
+ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
269
+ def __init__(
270
+ self,
271
+ module: torch.nn.Linear,
272
+ offload_dtype: torch.dtype = None,
273
+ offload_device: Union[str, torch.device] = None,
274
+ onload_dtype: torch.dtype = None,
275
+ onload_device: Union[str, torch.device] = None,
276
+ preparing_dtype: torch.dtype = None,
277
+ preparing_device: Union[str, torch.device] = None,
278
+ computation_dtype: torch.dtype = None,
279
+ computation_device: Union[str, torch.device] = None,
280
+ vram_limit: float = None,
281
+ name: str = "",
282
+ disk_map: DiskMap = None,
283
+ **kwargs
284
+ ):
285
+ with skip_model_initialization():
286
+ super().__init__(
287
+ in_features=module.in_features,
288
+ out_features=module.out_features,
289
+ bias=module.bias is not None,
290
+ )
291
+ self.set_dtype_and_device(
292
+ offload_dtype,
293
+ offload_device,
294
+ onload_dtype,
295
+ onload_device,
296
+ preparing_dtype,
297
+ preparing_device,
298
+ computation_dtype,
299
+ computation_device,
300
+ vram_limit,
301
+ )
302
+ self.weight = module.weight
303
+ self.bias = module.bias
304
+ self.state = 0
305
+ self.name = name
306
+ self.lora_A_weights = []
307
+ self.lora_B_weights = []
308
+ self.lora_merger = None
309
+ self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
310
+
311
+ if offload_dtype == "disk":
312
+ self.disk_map = disk_map
313
+ self.disk_offload = True
314
+ else:
315
+ self.disk_offload = False
316
+
317
+ def fp8_linear(
318
+ self,
319
+ input: torch.Tensor,
320
+ weight: torch.Tensor,
321
+ bias: torch.Tensor = None,
322
+ ) -> torch.Tensor:
323
+ device = input.device
324
+ origin_dtype = input.dtype
325
+ origin_shape = input.shape
326
+ input = input.reshape(-1, origin_shape[-1])
327
+
328
+ x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
329
+ fp8_max = 448.0
330
+ # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
331
+ # To avoid overflow and ensure numerical compatibility during FP8 computation,
332
+ # we scale down the input by 2.0 in advance.
333
+ # This scaling will be compensated later during the final result scaling.
334
+ if self.computation_dtype == torch.float8_e4m3fnuz:
335
+ fp8_max = fp8_max / 2.0
336
+ scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
337
+ scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
338
+ input = input / (scale_a + 1e-8)
339
+ input = input.to(self.computation_dtype)
340
+ weight = weight.to(self.computation_dtype)
341
+ bias = None if bias is None else bias.to(torch.bfloat16)
342
+
343
+ result = torch._scaled_mm(
344
+ input,
345
+ weight.T,
346
+ scale_a=scale_a,
347
+ scale_b=scale_b.T,
348
+ bias=bias,
349
+ out_dtype=origin_dtype,
350
+ )
351
+ new_shape = origin_shape[:-1] + result.shape[-1:]
352
+ result = result.reshape(new_shape)
353
+ return result
354
+
355
+ def load_from_disk(self, torch_dtype, device, assign=True):
356
+ weight = self.disk_map[self.name + ".weight"].to(dtype=torch_dtype, device=device)
357
+ bias = None if self.bias is None else self.disk_map[self.name + ".bias"].to(dtype=torch_dtype, device=device)
358
+ if assign:
359
+ state_dict = {"weight": weight}
360
+ if bias is not None: state_dict["bias"] = bias
361
+ self.load_state_dict(state_dict, assign=True)
362
+ return weight, bias
363
+
364
+ def offload(self):
365
+ # offload / onload / preparing -> offload
366
+ if self.state != 0:
367
+ if self.disk_offload:
368
+ self.to("meta")
369
+ else:
370
+ self.to(dtype=self.offload_dtype, device=self.offload_device)
371
+ self.state = 0
372
+
373
+ def onload(self):
374
+ # offload / onload / preparing -> onload
375
+ if self.state < 1:
376
+ if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk":
377
+ self.load_from_disk(self.onload_dtype, self.onload_device)
378
+ elif self.onload_device != "disk":
379
+ self.to(dtype=self.onload_dtype, device=self.onload_device)
380
+ self.state = 1
381
+
382
+ def preparing(self):
383
+ # onload / preparing -> preparing
384
+ if self.state != 2:
385
+ if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk":
386
+ self.load_from_disk(self.preparing_dtype, self.preparing_device)
387
+ elif self.preparing_device != "disk":
388
+ self.to(dtype=self.preparing_dtype, device=self.preparing_device)
389
+ self.state = 2
390
+
391
+ def computation(self):
392
+ # onload / preparing -> computation (temporary)
393
+ if self.state == 2:
394
+ torch_dtype, device = self.preparing_dtype, self.preparing_device
395
+ else:
396
+ torch_dtype, device = self.onload_dtype, self.onload_device
397
+ if torch_dtype == self.computation_dtype and device == self.computation_device:
398
+ weight, bias = self.weight, self.bias
399
+ elif self.disk_offload and device == "disk":
400
+ weight, bias = self.load_from_disk(self.computation_dtype, self.computation_device, assign=False)
401
+ else:
402
+ weight = self.cast_to(self.weight, self.computation_dtype, self.computation_device)
403
+ bias = None if self.bias is None else self.cast_to(self.bias, self.computation_dtype, self.computation_device)
404
+ return weight, bias
405
+
406
+ def linear_forward(self, x, weight, bias):
407
+ if self.enable_fp8:
408
+ out = self.fp8_linear(x, weight, bias)
409
+ else:
410
+ out = torch.nn.functional.linear(x, weight, bias)
411
+ return out
412
+
413
+ def lora_forward(self, x, out):
414
+ if self.lora_merger is None:
415
+ for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
416
+ out = out + x @ lora_A.T @ lora_B.T
417
+ else:
418
+ lora_output = []
419
+ for lora_A, lora_B in zip(self.lora_A_weights, self.lora_B_weights):
420
+ lora_output.append(x @ lora_A.T @ lora_B.T)
421
+ lora_output = torch.stack(lora_output)
422
+ out = self.lora_merger(out, lora_output)
423
+ return out
424
+
425
+ def forward(self, x, *args, **kwargs):
426
+ if self.state == 1 and (self.vram_limit is None or self.check_free_vram()):
427
+ self.preparing()
428
+ weight, bias = self.computation()
429
+ out = self.linear_forward(x, weight, bias)
430
+ if len(self.lora_A_weights) > 0:
431
+ out = self.lora_forward(x, out)
432
+ return out
433
+
434
+
435
+ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, name_prefix="", disk_map=None, **kwargs):
436
+ if isinstance(model, AutoWrappedNonRecurseModule):
437
+ model = model.module
438
+ for name, module in model.named_children():
439
+ layer_name = name if name_prefix == "" else name_prefix + "." + name
440
+ for source_module, target_module in module_map.items():
441
+ if isinstance(module, source_module):
442
+ module_ = target_module(module, **vram_config, vram_limit=vram_limit, name=layer_name, disk_map=disk_map, **kwargs)
443
+ if isinstance(module_, AutoWrappedNonRecurseModule):
444
+ enable_vram_management_recursively(module_, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
445
+ setattr(model, name, module_)
446
+ break
447
+ else:
448
+ enable_vram_management_recursively(module, module_map, vram_config, vram_limit=vram_limit, name_prefix=layer_name, disk_map=disk_map, **kwargs)
449
+
450
+
451
+ def fill_vram_config(model, vram_config):
452
+ vram_config_ = vram_config.copy()
453
+ vram_config_["onload_dtype"] = vram_config["computation_dtype"]
454
+ vram_config_["onload_device"] = vram_config["computation_device"]
455
+ vram_config_["preparing_dtype"] = vram_config["computation_dtype"]
456
+ vram_config_["preparing_device"] = vram_config["computation_device"]
457
+ for k in vram_config:
458
+ if vram_config[k] != vram_config_[k]:
459
+ print(f"No fine-grained VRAM configuration is provided for {model.__class__.__name__}. [`onload`, `preparing`, `computation`] will be the same state. `vram_config` is set to {vram_config_}")
460
+ break
461
+ return vram_config_
462
+
463
+
464
+ def enable_vram_management(model: torch.nn.Module, module_map: dict, vram_config: dict, vram_limit=None, disk_map=None, **kwargs):
465
+ for source_module, target_module in module_map.items():
466
+ # If no fine-grained VRAM configuration is provided, the entire model will be managed uniformly.
467
+ if isinstance(model, source_module):
468
+ vram_config = fill_vram_config(model, vram_config)
469
+ model = target_module(model, **vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
470
+ break
471
+ else:
472
+ enable_vram_management_recursively(model, module_map, vram_config, vram_limit=vram_limit, disk_map=disk_map, **kwargs)
473
+ # `vram_management_enabled` is a flag that allows the pipeline to determine whether VRAM management is enabled.
474
+ model.vram_management_enabled = True
475
+ return model
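
A minimal usage sketch of the wrappers above (illustrative only: the import path, `module_map`, dtypes, and devices are assumptions; the real per-model maps live in diffsynth/configs/vram_management_module_maps.py):

import torch
# Assumed import path; adjust to wherever these classes are exposed in your checkout.
from diffsynth.core.vram.layers import AutoWrappedLinear, enable_vram_management

# Keep Linear weights on CPU in bfloat16 and cast them to the GPU only when a forward pass needs them.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
module_map = {torch.nn.Linear: AutoWrappedLinear}
vram_config = dict(
    offload_dtype=torch.bfloat16, offload_device="cpu",
    onload_dtype=torch.bfloat16, onload_device="cpu",
    preparing_dtype=torch.bfloat16, preparing_device="cuda",
    computation_dtype=torch.bfloat16, computation_device="cuda",
)
model = enable_vram_management(model, module_map, vram_config)
x = torch.randn(1, 64, dtype=torch.bfloat16, device="cuda")
y = model(x)  # weights are cast to CUDA on the fly, layer by layer
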
diffsynth/datasets/mvdataset.py ADDED
@@ -0,0 +1,393 @@
1
+ import imageio, os, torch, warnings, torchvision, argparse, json
2
+ from peft import LoraConfig, inject_adapter_in_model
3
+ from PIL import Image
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+ from accelerate import Accelerator
7
+ from accelerate.utils import DistributedDataParallelKwargs
8
+ import random
9
+ from decord import VideoReader
10
+ from decord import cpu, gpu
11
+ import imageio.v3 as iio
12
+
13
+ from torchvision import transforms
16
+ import decord
18
+ import re
19
+ decord.bridge.set_bridge('torch')
21
+ import numpy as np
22
+ from PIL import Image, ImageOps
23
+
24
+ class MulltiShot_MultiView_Dataset(torch.utils.data.Dataset):
25
+ def __init__(self, dataset_base_path='/root/paddlejob/workspace/qizipeng/baidu/personal-code/Multi-view/multi_view/datasets/merged_mark_paishe_ds_meg_merge_dwposefilter_paishe.json',
26
+ ref_image_path='/root/paddlejob/workspace/qizipeng/code/longvideogen/output.json',
27
+ time_division_factor=4,
28
+ time_division_remainder=1,
29
+ max_pixels=1920*1080,
30
+ height_division_factor=16, width_division_factor=16,
31
+ transform=None,
32
+ length=None,
33
+ resolution=None,
34
+ prev_length=5,
35
+ ref_num = 3,
36
+ training = True):
37
+ self.data_path = dataset_base_path
38
+ self.data = []
39
+ self.length = length
40
+ self.resolution = resolution
41
+ self.height, self.width = resolution
42
+ self.num_frames = length
43
+ self.time_division_factor = time_division_factor
44
+ self.time_division_remainder = time_division_remainder
45
+ self.max_pixels = max_pixels
46
+ self.height_division_factor = height_division_factor
47
+ self.width_division_factor = width_division_factor
48
+ self.prev_length = prev_length
49
+ self.training = training
50
+ self.ref_num = ref_num
51
+
52
+ with open(self.data_path, 'r') as f:
53
+ meta_datas = json.load(f)
54
+
55
+ for video_path in tqdm(meta_datas.keys()):
56
+ context = meta_datas[video_path]
57
+ candidate_labels = list(context.keys())
58
+ candidate_labels.remove('text')
59
+
60
+ disk_path = meta_datas[video_path]["disk_path"]
61
+ if not disk_path.lower().endswith(".mp4"):
62
+ continue
63
+
64
+
65
+ # reader = imageio.get_reader(meta_datas[video_path]["disk_path"])
66
+ # total_original_frames = reader.count_frames()
67
+ # total_frame = total_original_frames # context["end_index"] - context["start_index"] - 1
68
+ total_frame = None
69
+ ref_id = self.get_ref_id(face_crop_angle = context['facedetect_v1'], facedetect_v1_frame_index = context['facedetect_v1_frame_index'], total_frame = total_frame)
70
+ if ref_id == []:
71
+ continue
72
+ ref_id_all = []
73
+ for ids in ref_id:
74
+ ref_id_grop = []
75
+ for id in ids:
76
+ coordinate = context['facedetect_v1'][id][0]['detect']
77
+ if context['facedetect_v1'][id][0]['detect']["prob"] < 0.99:
78
+ continue
79
+ top, height, width, left = coordinate['top'], coordinate['height'], coordinate['width'], coordinate['left']
80
+ if not(min(height, width) > 80 ):
81
+ continue
82
+ # bbox enlargement factor (currently 1x, i.e. no enlargement)
83
+ width = int(width * 1)
84
+ height = int(height * 1)
85
+ frame_index = context['facedetect_v1_frame_index'][id]
86
+ ref_id_grop.append([top, height, width, left, int(frame_index)])
87
+ if ref_id_grop != []:
88
+ if len(ref_id_grop) >= 3: # self.ref_num; keep >= 3 so the data stays consistent with ref_num = 3
89
+ ref_id_all.append(ref_id_grop)
90
+ if ref_id_all == []:
91
+ continue
92
+ meta_prompt = {}
93
+ meta_prompt["global_caption"] = None
94
+ meta_prompt["per_shot_prompt"] = []
95
+ meta_prompt["single_prompt"] = context['text']
96
+ self.data.append({'video_path': disk_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all})
97
+ # self.data.append({'video_path':video_path, 'meta_prompt': meta_prompt, "ref_id_all": ref_id_all})
98
+
99
+ random.seed(42) # keep the train/test split deterministic (optional)
100
+ total = len(self.data)
101
+ test_count = max(1, int(total * 0.05)) # at least one test sample
102
+
103
+ # randomly pick indices for the test split
104
+ test_indices = set(random.sample(range(total), test_count))
105
+
106
+ self.data_test = [self.data[i] for i in range(total) if i in test_indices]
107
+ self.data_train = [self.data[i] for i in range(total) if i not in test_indices]
108
+ print(f"🔥 数据集划分完成:Train={len(self.data_train)}, Test={len(self.data_test)}")
109
+
110
+ if self.height is not None and self.width is not None:
111
+ print("Height and width are fixed. Setting `dynamic_resolution` to False.")
112
+ self.dynamic_resolution = False
113
+ elif self.height is None and self.width is None:
114
+ print("Height and width are none. Setting `dynamic_resolution` to True.")
115
+ self.dynamic_resolution = True
116
+
117
+ def get_ref_id(self, face_crop_angle, facedetect_v1_frame_index = None, total_frame = None, angle_threshold=50):
118
+ """
119
+ Return index triplets [i, j, k] that satisfy the head-pose difference requirement.
120
+ Requirements:
121
+ - face_crop_angle[i] / [j] / [k] must all be non-empty
122
+ - at least one yaw/pitch/roll difference between i and j is > angle_threshold
123
+ - k != i != j, and k must also be non-empty
124
+ """
125
+ ref_id = []
126
+ max_try = 5
127
+ need_max = 3
128
+ try_num = 0
129
+
130
+ # filter out empty entries and keep only valid indices
131
+ valid_indices = [idx for idx, item in enumerate(face_crop_angle) if item]
132
+ N = len(valid_indices)
133
+
134
+ if N < 3:
135
+ return ref_id # fewer than 3 valid faces, cannot form a triplet
136
+
137
+ # check the angle difference for every pair
138
+ for a in range(N - 1):
139
+ i = valid_indices[a]
140
+ # if facedetect_v1_frame_index[i] > total_frame:
141
+ # continue
142
+ angle_i = face_crop_angle[i][0]["angle"]
143
+
144
+ for b in range(a + 1, N):
145
+ j = valid_indices[b]
146
+ # if facedetect_v1_frame_index[j] > total_frame:
147
+ # continue
148
+ angle_j = face_crop_angle[j][0]["angle"]
149
+
150
+ # check whether the angle threshold is met
151
+ if (
152
+ abs(angle_i["yaw"] - angle_j["yaw"]) > angle_threshold or
153
+ abs(angle_i["pitch"] - angle_j["pitch"]) > angle_threshold or
154
+ abs(angle_i["roll"] - angle_j["roll"]) > angle_threshold
155
+ ):
156
+ # find a third index k
157
+ for c in range(N):
158
+ k = valid_indices[c]
159
+ # if facedetect_v1_frame_index[k] > total_frame:
160
+ # continue
161
+ if k != i and k != j:
162
+ ref_id.append([i, j, k])
163
+ break
164
+
165
+ try_num += 1
166
+ if try_num >= max_try or len(ref_id) >= need_max:
167
+ return ref_id
168
+
169
+ return ref_id
170
+ def crop_and_resize(self, image, target_height, target_width):
171
+ width, height = image.size
172
+ scale = max(target_width / width, target_height / height)
173
+ image = torchvision.transforms.functional.resize(
174
+ image,
175
+ (round(height*scale), round(width*scale)),
176
+ interpolation=torchvision.transforms.InterpolationMode.BILINEAR
177
+ )
178
+ image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
179
+ return image
180
+
181
+ def get_height_width(self, image):
182
+ if self.dynamic_resolution:
183
+ width, height = image.size
184
+ if width * height > self.max_pixels:
185
+ scale = (width * height / self.max_pixels) ** 0.5
186
+ height, width = int(height / scale), int(width / scale)
187
+ height = height // self.height_division_factor * self.height_division_factor
188
+ width = width // self.width_division_factor * self.width_division_factor
189
+ else:
190
+ height, width = self.height, self.width
191
+ return height, width
192
+
212
+ def resize_ref(self, img, target_h, target_w):
213
+ h = target_h
214
+ w = target_w
215
+ img = img.convert("RGB")
216
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
217
+ img_ratio = img.width / img.height
218
+ target_ratio = w / h
219
+
220
+ if img_ratio > target_ratio: # Image is wider than target
221
+ new_width = w
222
+ new_height = int(new_width / img_ratio)
223
+ else: # Image is taller than target
224
+ new_height = h
225
+ new_width = int(new_height * img_ratio)
226
+
227
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
228
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
229
+
230
+ # Create a new image with the target size and place the resized image in the center
231
+ delta_w = w - img.size[0]
232
+ delta_h = h - img.size[1]
233
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
234
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
235
+
236
+ return new_img
237
+
238
+
239
+ def load_video_crop_ref_image(self, video_path=None, ref_id_all=[[]]):
240
+ ### fps conversion
241
+ reader = imageio.get_reader(video_path)
242
+ meta = reader.get_meta_data()
243
+ original_fps = meta.get("fps", 24)
244
+ target_fps = 16
245
+ duration_seconds = 5
246
+ target_frames = target_fps * duration_seconds + 1 # = 81 frames
247
+
248
+ # ---- get the original frame count ----
249
+ try:
250
+ total_original_frames = reader.count_frames()
251
+ except Exception:
252
+ total_original_frames = int(meta.get("duration", 5) * original_fps)
253
+
254
+
255
+
256
+ # ---- how many source frames are needed (5 seconds) ----
257
+ need_orig_frames = int(original_fps * duration_seconds)
258
+
259
+ # ---- Case 1: source video >= 5s: randomly choose a 5-second start ----
260
+ if total_original_frames > need_orig_frames:
261
+ max_start = total_original_frames - need_orig_frames
262
+ start_frame = random.randint(0, max_start)
263
+ segment_start = start_frame
264
+ segment_end = start_frame + need_orig_frames
265
+ else:
266
+ # ---- Case 2: source video < 5s: use all frames ----
267
+ segment_start = 0
268
+ segment_end = total_original_frames
269
+
270
+ # ---- uniformly sample target_frames frames ----
271
+ sample_ids = np.linspace(segment_start, segment_end - 1, num=target_frames, dtype=int)
272
+
273
+ frames = []
274
+ for frame_id in sample_ids:
275
+ frame = reader.get_data(int(frame_id))
276
+ frame = Image.fromarray(frame)
277
+ frame = self.crop_and_resize(frame, *self.get_height_width(frame))
278
+ frames.append(frame)
279
+
280
+ # ===========================
281
+ # Reference-image selection
282
+ # ===========================
283
+
284
+ # 1) randomly pick one group from ref_id_all (a nested list of face boxes)
285
+ # ref_images_all = [ [img1, img2, img3], [imgA, imgB, imgC], ... ]
286
+ ref_group = random.choice(ref_id_all)
287
+
288
+ # 2) make sure the group contains enough reference faces
289
+ if len(ref_group) < self.ref_num:
290
+ raise ValueError(f"需要 {self.ref_num} 张参考图,但该组只有 {len(ref_group)} 张。")
291
+
292
+ # 3) randomly sample self.ref_num reference faces from the group
293
+ selected_refs = random.sample(ref_group, self.ref_num)
294
+ random.shuffle(selected_refs)
295
+
296
+ ref_images = []
297
+ for sf in selected_refs:
298
+ top, height, width, left, frame_index = sf
300
+ if frame_index > total_original_frames:
301
+ print(f"{video_path}, frame_index({frame_index}) out of range")
302
+ frame = reader.get_data(int(frame_index))
303
+ frame = Image.fromarray(frame)
304
+ xmin, ymin, xmax, ymax = left, top, left + width, top + height
305
+ cropped_image = frame.crop((xmin, ymin, xmax, ymax)).convert("RGB")
306
+ cropped_image = self.resize_ref(cropped_image, self.height, self.width)
307
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
308
+ ref_images.append(cropped_image)
309
+ reader.close()
310
+
311
+ return frames, ref_images
312
+
313
+ def __getitem__(self, index):
314
+ max_retry = 10 # retry at most 10 times to avoid an infinite loop
315
+ retry = 0
316
+
317
+ while retry < max_retry:
318
+ # ----- pick a sample from the train / test split -----
319
+ if self.training:
320
+ meta_data = self.data_train[index % len(self.data_train)]
321
+ else:
322
+ meta_data = self.data_test[index % len(self.data_test)]
323
+
324
+ video_path = meta_data['video_path']
325
+ meta_prompt = meta_data['meta_prompt']
326
+ ref_id_all = meta_data['ref_id_all']
327
+
328
+ # ----- try to load the video and reference images -----
329
+ try:
330
+ input_video, ref_images = self.load_video_crop_ref_image(
331
+ video_path=video_path,
332
+ ref_id_all=ref_id_all
333
+ )
334
+ except Exception as e:
335
+ print("❌ Exception in load_video_crop_ref_image")
336
+ print(f" video_path: {video_path}")
337
+ print(f" error type: {type(e).__name__}")
338
+ print(f" error msg : {e}")
339
+
340
+ # print the traceback to make the failure easier to locate
341
+ import traceback
342
+ traceback.print_exc()
343
+ input_video = None
344
+ ref_images = None
345
+ # ----- on success with a non-empty video, return the sample -----
346
+ if input_video is not None and len(input_video) > 0:
347
+ return {
348
+ "global_caption": None,
349
+ "shot_num": 1,
350
+ "pre_shot_caption": [],
351
+ "single_caption": meta_prompt["single_prompt"],
352
+ "video": input_video,
353
+ "ref_num": self.ref_num,
354
+ "ref_images": ref_images,
355
+ "video_path": video_path
356
+ }
357
+
358
+ # ----- on failure, switch to another index and try again -----
359
+ retry += 1
360
+ index = random.randint(0, len(self.data_train) - 1 if self.training else len(self.data_test) - 1)
361
+
362
+ # if all retries fail, raise an error
363
+ raise RuntimeError(f"❌ [Dataset] Failed to load video/ref after {max_retry} retries.")
364
+
365
+ def __len__(self):
366
+ if self.training:
367
+ return len(self.data_train)
368
+ else:
369
+ return len(self.data_test)
370
+
371
+ if __name__ == '__main__':
372
+ from torch.utils.data import DataLoader
373
+ dataset = MulltiShot_MultiView_Dataset(length=49, resolution=(384, 640), training=True)
374
+ print(len(dataset))
375
+ metadata = dataset[0]
376
+ # results = dataset[0]
377
+ # loader = DataLoader(
378
+ # dataset,
379
+ # batch_size=1, # videos are usually loaded one clip per batch
380
+ # shuffle=False, # set to True to shuffle
381
+ # num_workers=10, # load in parallel with worker subprocesses
382
+ # pin_memory=True,
383
+ # prefetch_factor=2, # each worker prefetches 2 samples
384
+ # collate_fn=lambda x: x[0], # no collation at all
385
+ # )
386
+
387
+ # for batch in tqdm(loader):
388
+ # pass
389
+ for i in tqdm(range(len(dataset))):
390
+ file = dataset[i]
391
+
392
+ assert 0
393
+
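
For reference, a sketch of iterating the dataset with multi-worker loading, based on the commented-out block above (the constructor defaults point at machine-specific JSON files, so `dataset_base_path` below is a placeholder for your own metadata):

from torch.utils.data import DataLoader
from tqdm import tqdm

dataset = MulltiShot_MultiView_Dataset(
    dataset_base_path="/path/to/your_metadata.json",  # placeholder path
    length=49, resolution=(384, 640), training=True,
)
loader = DataLoader(
    dataset,
    batch_size=1,               # one clip per step
    shuffle=True,
    num_workers=8,              # decode videos in parallel worker processes
    pin_memory=True,
    prefetch_factor=2,
    collate_fn=lambda x: x[0],  # samples are dicts of PIL images; skip default collation
)
for sample in tqdm(loader):
    frames, refs = sample["video"], sample["ref_images"]
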
diffsynth/diffusion/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .flow_match import FlowMatchScheduler
2
+ from .training_module import DiffusionTrainingModule
3
+ from .logger import ModelLogger
4
+ from .runner import launch_training_task, launch_data_process_task
5
+ from .parsers import *
6
+ from .loss import *
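
With these re-exports in place, training scripts can presumably import the training utilities directly from the subpackage, for example:

from diffsynth.diffusion import FlowMatchScheduler, ModelLogger, launch_training_task
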
diffsynth/diffusion/base_pipeline.py ADDED
@@ -0,0 +1,439 @@
1
+ from PIL import Image
2
+ import torch
3
+ import numpy as np
4
+ from einops import repeat, reduce
5
+ from typing import Union
6
+ from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig
7
+ from ..utils.lora import GeneralLoRALoader
8
+ from ..models.model_loader import ModelPool
9
+ from ..utils.controlnet import ControlNetInput
10
+
11
+
12
+ class PipelineUnit:
13
+ def __init__(
14
+ self,
15
+ seperate_cfg: bool = False,
16
+ take_over: bool = False,
17
+ input_params: tuple[str] = None,
18
+ output_params: tuple[str] = None,
19
+ input_params_posi: dict[str, str] = None,
20
+ input_params_nega: dict[str, str] = None,
21
+ onload_model_names: tuple[str] = None
22
+ ):
23
+ self.seperate_cfg = seperate_cfg
24
+ self.take_over = take_over
25
+ self.input_params = input_params
26
+ self.output_params = output_params
27
+ self.input_params_posi = input_params_posi
28
+ self.input_params_nega = input_params_nega
29
+ self.onload_model_names = onload_model_names
30
+
31
+ def fetch_input_params(self):
32
+ params = []
33
+ if self.input_params is not None:
34
+ for param in self.input_params:
35
+ params.append(param)
36
+ if self.input_params_posi is not None:
37
+ for _, param in self.input_params_posi.items():
38
+ params.append(param)
39
+ if self.input_params_nega is not None:
40
+ for _, param in self.input_params_nega.items():
41
+ params.append(param)
42
+ params = sorted(list(set(params)))
43
+ return params
44
+
45
+ def fetch_output_params(self):
46
+ params = []
47
+ if self.output_params is not None:
48
+ for param in self.output_params:
49
+ params.append(param)
50
+ return params
51
+
52
+ def process(self, pipe, **kwargs) -> dict:
53
+ return {}
54
+
55
+ def post_process(self, pipe, **kwargs) -> dict:
56
+ return {}
57
+
58
+
59
+ class BasePipeline(torch.nn.Module):
60
+
61
+ def __init__(
62
+ self,
63
+ device="cuda", torch_dtype=torch.float16,
64
+ height_division_factor=64, width_division_factor=64,
65
+ time_division_factor=None, time_division_remainder=None,
66
+ ):
67
+ super().__init__()
68
+ # The device and torch_dtype is used for the storage of intermediate variables, not models.
69
+ self.device = device
70
+ self.torch_dtype = torch_dtype
71
+ # The following parameters are used for shape check.
72
+ self.height_division_factor = height_division_factor
73
+ self.width_division_factor = width_division_factor
74
+ self.time_division_factor = time_division_factor
75
+ self.time_division_remainder = time_division_remainder
76
+ # VRAM management
77
+ self.vram_management_enabled = False
78
+ # Pipeline Unit Runner
79
+ self.unit_runner = PipelineUnitRunner()
80
+ # LoRA Loader
81
+ self.lora_loader = GeneralLoRALoader
82
+
83
+
84
+ def to(self, *args, **kwargs):
85
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
86
+ if device is not None:
87
+ self.device = device
88
+ if dtype is not None:
89
+ self.torch_dtype = dtype
90
+ super().to(*args, **kwargs)
91
+ return self
92
+
93
+
94
+ def check_resize_height_width(self, height, width, num_frames=None):
95
+ # Shape check
96
+ if height % self.height_division_factor != 0:
97
+ height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
98
+ print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
99
+ if width % self.width_division_factor != 0:
100
+ width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
101
+ print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
102
+ if num_frames is None:
103
+ return height, width
104
+ else:
105
+ if num_frames % self.time_division_factor != self.time_division_remainder:
106
+ num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
107
+ print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
108
+ return height, width, num_frames
109
+
110
+
111
+ def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
112
+ # Transform a PIL.Image to torch.Tensor
113
+ image = torch.Tensor(np.array(image, dtype=np.float32))
114
+ image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
115
+ image = image * ((max_value - min_value) / 255) + min_value
116
+ image = repeat(image, f"H W C -> {pattern}", **({"B": 1} if "B" in pattern else {}))
117
+ return image
118
+
119
+
120
+ def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
121
+ # Transform a list of PIL.Image to torch.Tensor
122
+ video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
123
+ video = torch.stack(video, dim=pattern.index("T") // 2)
124
+ return video
125
+
126
+
127
+ def vae_output_to_image(self, vae_output, pattern="B C H W", min_value=-1, max_value=1):
128
+ # Transform a torch.Tensor to PIL.Image
129
+ if pattern != "H W C":
130
+ vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
131
+ image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
132
+ image = image.to(device="cpu", dtype=torch.uint8)
133
+ image = Image.fromarray(image.numpy())
134
+ return image
135
+
136
+
137
+ def vae_output_to_video(self, vae_output, pattern="B C T H W", min_value=-1, max_value=1):
138
+ # Transform a torch.Tensor to list of PIL.Image
139
+ if pattern != "T H W C":
140
+ vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
141
+ video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
142
+ return video
143
+
144
+
145
+ def load_models_to_device(self, model_names):
146
+ if self.vram_management_enabled:
147
+ # offload models
148
+ for name, model in self.named_children():
149
+ if name not in model_names:
150
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
151
+ if hasattr(model, "offload"):
152
+ model.offload()
153
+ else:
154
+ for module in model.modules():
155
+ if hasattr(module, "offload"):
156
+ module.offload()
157
+ torch.cuda.empty_cache()
158
+ # onload models
159
+ for name, model in self.named_children():
160
+ if name in model_names:
161
+ if hasattr(model, "vram_management_enabled") and model.vram_management_enabled:
162
+ if hasattr(model, "onload"):
163
+ model.onload()
164
+ else:
165
+ for module in model.modules():
166
+ if hasattr(module, "onload"):
167
+ module.onload()
168
+
169
+
170
+ def generate_noise(self, shape, seed=None, rand_device="cpu", rand_torch_dtype=torch.float32, device=None, torch_dtype=None):
171
+ # Initialize Gaussian noise
172
+ generator = None if seed is None else torch.Generator(rand_device).manual_seed(seed)
173
+ noise = torch.randn(shape, generator=generator, device=rand_device, dtype=rand_torch_dtype)
174
+ noise = noise.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
175
+ return noise
176
+
177
+
178
+ def get_vram(self):
179
+ return torch.cuda.mem_get_info(self.device)[1] / (1024 ** 3)
180
+
181
+ def get_module(self, model, name):
182
+ if "." in name:
183
+ name, suffix = name[:name.index(".")], name[name.index(".") + 1:]
184
+ if name.isdigit():
185
+ return self.get_module(model[int(name)], suffix)
186
+ else:
187
+ return self.get_module(getattr(model, name), suffix)
188
+ else:
189
+ return getattr(model, name)
190
+
191
+ def freeze_except(self, model_names):
192
+ self.eval()
193
+ self.requires_grad_(False)
194
+ for name in model_names:
195
+ module = self.get_module(self, name)
196
+ if module is None:
197
+ print(f"No {name} models in the pipeline. We cannot enable training on the model. If this occurs during the data processing stage, it is normal.")
198
+ continue
199
+ module.train()
200
+ module.requires_grad_(True)
201
+
202
+
203
+ def blend_with_mask(self, base, addition, mask):
204
+ return base * (1 - mask) + addition * mask
205
+
206
+
207
+ def step(self, scheduler, latents, progress_id, noise_pred, input_latents=None, inpaint_mask=None, **kwargs):
208
+ timestep = scheduler.timesteps[progress_id]
209
+ if inpaint_mask is not None:
210
+ noise_pred_expected = scheduler.return_to_timestep(scheduler.timesteps[progress_id], latents, input_latents)
211
+ noise_pred = self.blend_with_mask(noise_pred_expected, noise_pred, inpaint_mask)
212
+ latents_next = scheduler.step(noise_pred, timestep, latents)
213
+ return latents_next
214
+
215
+
216
+ def split_pipeline_units(self, model_names: list[str]):
217
+ return PipelineUnitGraph().split_pipeline_units(self.units, model_names)
218
+
219
+
220
+ def flush_vram_management_device(self, device):
221
+ for module in self.modules():
222
+ if isinstance(module, AutoTorchModule):
223
+ module.offload_device = device
224
+ module.onload_device = device
225
+ module.preparing_device = device
226
+ module.computation_device = device
227
+
228
+
229
+ def load_lora(
230
+ self,
231
+ module: torch.nn.Module,
232
+ lora_config: Union[ModelConfig, str] = None,
233
+ alpha=1,
234
+ hotload=None,
235
+ state_dict=None,
236
+ ):
237
+ if state_dict is None:
238
+ if isinstance(lora_config, str):
239
+ lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
240
+ else:
241
+ lora_config.download_if_necessary()
242
+ lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
243
+ else:
244
+ lora = state_dict
245
+ lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
246
+ lora = lora_loader.convert_state_dict(lora)
247
+ if hotload is None:
248
+ hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
249
+ if hotload:
250
+ if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
251
+ raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
252
+ updated_num = 0
253
+ for _, module in module.named_modules():
254
+ if isinstance(module, AutoWrappedLinear):
255
+ name = module.name
256
+ lora_a_name = f'{name}.lora_A.weight'
257
+ lora_b_name = f'{name}.lora_B.weight'
258
+ if lora_a_name in lora and lora_b_name in lora:
259
+ updated_num += 1
260
+ module.lora_A_weights.append(lora[lora_a_name] * alpha)
261
+ module.lora_B_weights.append(lora[lora_b_name])
262
+ print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
263
+ else:
264
+ lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
265
+
266
+
267
+ def clear_lora(self):
268
+ cleared_num = 0
269
+ for name, module in self.named_modules():
270
+ if isinstance(module, AutoWrappedLinear):
271
+ if hasattr(module, "lora_A_weights"):
272
+ if len(module.lora_A_weights) > 0:
273
+ cleared_num += 1
274
+ module.lora_A_weights.clear()
275
+ if hasattr(module, "lora_B_weights"):
276
+ module.lora_B_weights.clear()
277
+ print(f"{cleared_num} LoRA layers are cleared.")
278
+
279
+
280
+ def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
281
+ model_pool = ModelPool()
282
+ for model_config in model_configs:
283
+ model_config.download_if_necessary()
284
+ vram_config = model_config.vram_config()
285
+ vram_config["computation_dtype"] = vram_config["computation_dtype"] or self.torch_dtype
286
+ vram_config["computation_device"] = vram_config["computation_device"] or self.device
287
+ model_pool.auto_load_model(
288
+ model_config.path,
289
+ vram_config=vram_config,
290
+ vram_limit=vram_limit,
291
+ clear_parameters=model_config.clear_parameters,
292
+ )
293
+ return model_pool
294
+
295
+
296
+ def check_vram_management_state(self):
297
+ vram_management_enabled = False
298
+ for module in self.children():
299
+ if hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled"):
300
+ vram_management_enabled = True
301
+ return vram_management_enabled
302
+
303
+
304
+ def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
305
+ noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
306
+ if cfg_scale != 1.0:
307
+ noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
308
+ noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
309
+ else:
310
+ noise_pred = noise_pred_posi
311
+ return noise_pred
312
+
313
+
314
+ class PipelineUnitGraph:
315
+ def __init__(self):
316
+ pass
317
+
318
+ def build_edges(self, units: list[PipelineUnit]):
319
+ # Establish dependencies between units
320
+ # to search for subsequent related computation units.
321
+ last_compute_unit_id = {}
322
+ edges = []
323
+ for unit_id, unit in enumerate(units):
324
+ for input_param in unit.fetch_input_params():
325
+ if input_param in last_compute_unit_id:
326
+ edges.append((last_compute_unit_id[input_param], unit_id))
327
+ for output_param in unit.fetch_output_params():
328
+ last_compute_unit_id[output_param] = unit_id
329
+ return edges
330
+
331
+ def build_chains(self, units: list[PipelineUnit]):
332
+ # Establish updating chains for each variable
333
+ # to track their computation process.
334
+ params = sum([unit.fetch_input_params() + unit.fetch_output_params() for unit in units], [])
335
+ params = sorted(list(set(params)))
336
+ chains = {param: [] for param in params}
337
+ for unit_id, unit in enumerate(units):
338
+ for param in unit.fetch_output_params():
339
+ chains[param].append(unit_id)
340
+ return chains
341
+
342
+ def search_direct_unit_ids(self, units: list[PipelineUnit], model_names: list[str]):
343
+ # Search for units that directly participate in the model's computation.
344
+ related_unit_ids = []
345
+ for unit_id, unit in enumerate(units):
346
+ for model_name in model_names:
347
+ if unit.onload_model_names is not None and model_name in unit.onload_model_names:
348
+ related_unit_ids.append(unit_id)
349
+ break
350
+ return related_unit_ids
351
+
352
+ def search_related_unit_ids(self, edges, start_unit_ids, direction="target"):
353
+ # Search for subsequent related computation units.
354
+ related_unit_ids = [unit_id for unit_id in start_unit_ids]
355
+ while True:
356
+ neighbors = []
357
+ for source, target in edges:
358
+ if direction == "target" and source in related_unit_ids and target not in related_unit_ids:
359
+ neighbors.append(target)
360
+ elif direction == "source" and source not in related_unit_ids and target in related_unit_ids:
361
+ neighbors.append(source)
362
+ neighbors = sorted(list(set(neighbors)))
363
+ if len(neighbors) == 0:
364
+ break
365
+ else:
366
+ related_unit_ids.extend(neighbors)
367
+ related_unit_ids = sorted(list(set(related_unit_ids)))
368
+ return related_unit_ids
369
+
370
+ def search_updating_unit_ids(self, units: list[PipelineUnit], chains, related_unit_ids):
371
+ # If the input parameters of this subgraph are updated outside the subgraph,
372
+ # search for the units where these updates occur.
373
+ first_compute_unit_id = {}
374
+ for unit_id in related_unit_ids:
375
+ for param in units[unit_id].fetch_input_params():
376
+ if param not in first_compute_unit_id:
377
+ first_compute_unit_id[param] = unit_id
378
+ updating_unit_ids = []
379
+ for param in first_compute_unit_id:
380
+ unit_id = first_compute_unit_id[param]
381
+ chain = chains[param]
382
+ if unit_id in chain and chain.index(unit_id) != len(chain) - 1:
383
+ for unit_id_ in chain[chain.index(unit_id) + 1:]:
384
+ if unit_id_ not in related_unit_ids:
385
+ updating_unit_ids.append(unit_id_)
386
+ related_unit_ids.extend(updating_unit_ids)
387
+ related_unit_ids = sorted(list(set(related_unit_ids)))
388
+ return related_unit_ids
389
+
390
+ def split_pipeline_units(self, units: list[PipelineUnit], model_names: list[str]):
391
+ # Split the computation graph,
392
+ # separating all model-related computations.
393
+ related_unit_ids = self.search_direct_unit_ids(units, model_names)
394
+ edges = self.build_edges(units)
395
+ chains = self.build_chains(units)
396
+ while True:
397
+ num_related_unit_ids = len(related_unit_ids)
398
+ related_unit_ids = self.search_related_unit_ids(edges, related_unit_ids, "target")
399
+ related_unit_ids = self.search_updating_unit_ids(units, chains, related_unit_ids)
400
+ if len(related_unit_ids) == num_related_unit_ids:
401
+ break
402
+ else:
403
+ num_related_unit_ids = len(related_unit_ids)
404
+ related_units = [units[i] for i in related_unit_ids]
405
+ unrelated_units = [units[i] for i in range(len(units)) if i not in related_unit_ids]
406
+ return related_units, unrelated_units
407
+
408
+
409
+ class PipelineUnitRunner:
410
+ def __init__(self):
411
+ pass
412
+
413
+ def __call__(self, unit: PipelineUnit, pipe: BasePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict) -> tuple[dict, dict]:
414
+ if unit.take_over:
415
+ # Let the pipeline unit take over this function.
416
+ inputs_shared, inputs_posi, inputs_nega = unit.process(pipe, inputs_shared=inputs_shared, inputs_posi=inputs_posi, inputs_nega=inputs_nega)
417
+ elif unit.seperate_cfg:
418
+ # Positive side
419
+ processor_inputs = {name: inputs_posi.get(name_) for name, name_ in unit.input_params_posi.items()}
420
+ if unit.input_params is not None:
421
+ for name in unit.input_params:
422
+ processor_inputs[name] = inputs_shared.get(name)
423
+ processor_outputs = unit.process(pipe, **processor_inputs)
424
+ inputs_posi.update(processor_outputs)
425
+ # Negative side
426
+ if inputs_shared["cfg_scale"] != 1:
427
+ processor_inputs = {name: inputs_nega.get(name_) for name, name_ in unit.input_params_nega.items()}
428
+ if unit.input_params is not None:
429
+ for name in unit.input_params:
430
+ processor_inputs[name] = inputs_shared.get(name)
431
+ processor_outputs = unit.process(pipe, **processor_inputs)
432
+ inputs_nega.update(processor_outputs)
433
+ else:
434
+ inputs_nega.update(processor_outputs)
435
+ else:
436
+ processor_inputs = {name: inputs_shared.get(name) for name in unit.input_params}
437
+ processor_outputs = unit.process(pipe, **processor_inputs)
438
+ inputs_shared.update(processor_outputs)
439
+ return inputs_shared, inputs_posi, inputs_nega
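
A minimal sketch of how a custom unit plugs into the runner (illustrative; the parameter names `latents` and `latent_shift` are made up for this example):

import torch

class ShiftLatentsUnit(PipelineUnit):
    def __init__(self):
        super().__init__(
            input_params=("latents", "latent_shift"),
            output_params=("latents",),
        )

    def process(self, pipe, latents, latent_shift):
        # Returned dict entries are merged back into `inputs_shared` by the runner.
        return {"latents": latents + latent_shift}

runner = PipelineUnitRunner()
inputs_shared = {"latents": torch.zeros(1, 4, 8, 8), "latent_shift": 0.1, "cfg_scale": 1.0}
inputs_shared, _, _ = runner(ShiftLatentsUnit(), pipe=None, inputs_shared=inputs_shared, inputs_posi={}, inputs_nega={})
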
diffsynth/diffusion/flow_match.py ADDED
@@ -0,0 +1,179 @@
1
+ import torch, math
2
+ from typing_extensions import Literal
3
+
4
+
5
+ class FlowMatchScheduler():
6
+
7
+ def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
8
+ self.set_timesteps_fn = {
9
+ "FLUX.1": FlowMatchScheduler.set_timesteps_flux,
10
+ "Wan": FlowMatchScheduler.set_timesteps_wan,
11
+ "Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
12
+ "FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
13
+ "Z-Image": FlowMatchScheduler.set_timesteps_z_image,
14
+ }.get(template, FlowMatchScheduler.set_timesteps_flux)
15
+ self.num_train_timesteps = 1000
16
+
17
+ @staticmethod
18
+ def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
19
+ sigma_min = 0.003/1.002
20
+ sigma_max = 1.0
21
+ shift = 3 if shift is None else shift
22
+ num_train_timesteps = 1000
23
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
24
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
25
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
26
+ timesteps = sigmas * num_train_timesteps
27
+ return sigmas, timesteps
28
+
29
+ @staticmethod
30
+ def set_timesteps_wan(num_inference_steps=100, denoising_strength=1.0, shift=None):
31
+ sigma_min = 0.0
32
+ sigma_max = 1.0
33
+ shift = 5 if shift is None else shift
34
+ num_train_timesteps = 1000
35
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
36
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
37
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
38
+ timesteps = sigmas * num_train_timesteps
39
+ return sigmas, timesteps
40
+
41
+ @staticmethod
42
+ def _calculate_shift_qwen_image(image_seq_len, base_seq_len=256, max_seq_len=8192, base_shift=0.5, max_shift=0.9):
43
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
44
+ b = base_shift - m * base_seq_len
45
+ mu = image_seq_len * m + b
46
+ return mu
47
+
48
+ @staticmethod
49
+ def set_timesteps_qwen_image(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
50
+ sigma_min = 0.0
51
+ sigma_max = 1.0
52
+ num_train_timesteps = 1000
53
+ shift_terminal = 0.02
54
+ # Sigmas
55
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
56
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
57
+ # Mu
58
+ if exponential_shift_mu is not None:
59
+ mu = exponential_shift_mu
60
+ elif dynamic_shift_len is not None:
61
+ mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len)
62
+ else:
63
+ mu = 0.8
64
+ sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
65
+ # Shift terminal
66
+ one_minus_z = 1 - sigmas
67
+ scale_factor = one_minus_z[-1] / (1 - shift_terminal)
68
+ sigmas = 1 - (one_minus_z / scale_factor)
69
+ # Timesteps
70
+ timesteps = sigmas * num_train_timesteps
71
+ return sigmas, timesteps
72
+
73
+ @staticmethod
74
+ def compute_empirical_mu(image_seq_len, num_steps):
75
+ a1, b1 = 8.73809524e-05, 1.89833333
76
+ a2, b2 = 0.00016927, 0.45666666
77
+
78
+ if image_seq_len > 4300:
79
+ mu = a2 * image_seq_len + b2
80
+ return float(mu)
81
+
82
+ m_200 = a2 * image_seq_len + b2
83
+ m_10 = a1 * image_seq_len + b1
84
+
85
+ a = (m_200 - m_10) / 190.0
86
+ b = m_200 - 200.0 * a
87
+ mu = a * num_steps + b
88
+
89
+ return float(mu)
90
+
91
+ @staticmethod
92
+ def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
93
+ sigma_min = 1 / num_inference_steps
94
+ sigma_max = 1.0
95
+ num_train_timesteps = 1000
96
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
97
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
98
+ mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
99
+ sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
100
+ timesteps = sigmas * num_train_timesteps
101
+ return sigmas, timesteps
102
+
103
+ @staticmethod
104
+ def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
105
+ sigma_min = 0.0
106
+ sigma_max = 1.0
107
+ shift = 3 if shift is None else shift
108
+ num_train_timesteps = 1000
109
+ sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
110
+ sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
111
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
112
+ timesteps = sigmas * num_train_timesteps
113
+ if target_timesteps is not None:
114
+ target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
115
+ for timestep in target_timesteps:
116
+ timestep_id = torch.argmin((timesteps - timestep).abs())
117
+ timesteps[timestep_id] = timestep
118
+ return sigmas, timesteps
119
+
120
+ def set_training_weight(self):
121
+ steps = 1000
122
+ x = self.timesteps
123
+ y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
124
+ y_shifted = y - y.min()
125
+ bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
126
+ if len(self.timesteps) != 1000:
127
+ # This is an empirical formula.
128
+ bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
129
+ bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
130
+ self.linear_timesteps_weights = bsmntw_weighing
131
+
132
+ def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
133
+ self.sigmas, self.timesteps = self.set_timesteps_fn(
134
+ num_inference_steps=num_inference_steps,
135
+ denoising_strength=denoising_strength,
136
+ **kwargs,
137
+ )
138
+ if training:
139
+ self.set_training_weight()
140
+ self.training = True
141
+ else:
142
+ self.training = False
143
+
144
+ def step(self, model_output, timestep, sample, to_final=False, **kwargs):
145
+ if isinstance(timestep, torch.Tensor):
146
+ timestep = timestep.cpu()
147
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
148
+ sigma = self.sigmas[timestep_id]
149
+ if to_final or timestep_id + 1 >= len(self.timesteps):
150
+ sigma_ = 0
151
+ else:
152
+ sigma_ = self.sigmas[timestep_id + 1]
153
+ prev_sample = sample + model_output * (sigma_ - sigma)
154
+ return prev_sample
155
+
156
+ def return_to_timestep(self, timestep, sample, sample_stablized):
157
+ if isinstance(timestep, torch.Tensor):
158
+ timestep = timestep.cpu()
159
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
160
+ sigma = self.sigmas[timestep_id]
161
+ model_output = (sample - sample_stablized) / sigma
162
+ return model_output
163
+
164
+ def add_noise(self, original_samples, noise, timestep):
165
+ if isinstance(timestep, torch.Tensor):
166
+ timestep = timestep.cpu()
167
+ timestep_id = torch.argmin((self.timesteps - timestep).abs())
168
+ sigma = self.sigmas[timestep_id]
169
+ sample = (1 - sigma) * original_samples + sigma * noise
170
+ return sample
171
+
172
+ def training_target(self, sample, noise, timestep):
173
+ target = noise - sample
174
+ return target
175
+
176
+ def training_weight(self, timestep):
177
+ timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
178
+ weights = self.linear_timesteps_weights[timestep_id]
179
+ return weights
diffsynth/diffusion/logger.py ADDED
@@ -0,0 +1,43 @@
1
+ import os, torch
2
+ from accelerate import Accelerator
3
+
4
+
5
+ class ModelLogger:
6
+ def __init__(self, output_path, remove_prefix_in_ckpt=None, state_dict_converter=lambda x:x):
7
+ self.output_path = output_path
8
+ self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
9
+ self.state_dict_converter = state_dict_converter
10
+ self.num_steps = 0
11
+
12
+
13
+ def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
14
+ self.num_steps += 1
15
+ if save_steps is not None and self.num_steps % save_steps == 0:
16
+ self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
17
+
18
+
19
+ def on_epoch_end(self, accelerator: Accelerator, model: torch.nn.Module, epoch_id):
20
+ accelerator.wait_for_everyone()
21
+ if accelerator.is_main_process:
22
+ state_dict = accelerator.get_state_dict(model)
23
+ state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
24
+ state_dict = self.state_dict_converter(state_dict)
25
+ os.makedirs(self.output_path, exist_ok=True)
26
+ path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
27
+ accelerator.save(state_dict, path, safe_serialization=True)
28
+
29
+
30
+ def on_training_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
31
+ if save_steps is not None and self.num_steps % save_steps != 0:
32
+ self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
33
+
34
+
35
+ def save_model(self, accelerator: Accelerator, model: torch.nn.Module, file_name):
36
+ accelerator.wait_for_everyone()
37
+ if accelerator.is_main_process:
38
+ state_dict = accelerator.get_state_dict(model)
39
+ state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
40
+ state_dict = self.state_dict_converter(state_dict)
41
+ os.makedirs(self.output_path, exist_ok=True)
42
+ path = os.path.join(self.output_path, file_name)
43
+ accelerator.save(state_dict, path, safe_serialization=True)
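
A brief sketch of how ModelLogger is meant to be driven from a training loop (output path and prefix are illustrative; the wrapped model is expected to implement export_trainable_state_dict, as assumed above):

    from accelerate import Accelerator
    from diffsynth.diffusion.logger import ModelLogger  # path as added above

    accelerator = Accelerator()
    model_logger = ModelLogger("./models/train", remove_prefix_in_ckpt="pipe.dit.")
    # Inside the loop, after each optimizer step:
    #     model_logger.on_step_end(accelerator, model, save_steps=500)
    # After each epoch, and once training finishes:
    #     model_logger.on_epoch_end(accelerator, model, epoch_id)
    #     model_logger.on_training_end(accelerator, model, save_steps=500)
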
diffsynth/diffusion/loss.py ADDED
@@ -0,0 +1,119 @@
1
+ from .base_pipeline import BasePipeline
2
+ import torch
3
+
4
+
5
+ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
6
+ max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
7
+ min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
8
+
9
+ timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
10
+ timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
11
+
12
+ noise = torch.randn_like(inputs["input_latents"])
13
+ inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
14
+ training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
15
+
16
+ models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
17
+ noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
18
+
19
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
20
+ loss = loss * pipe.scheduler.training_weight(timestep)
21
+ return loss
22
+
23
+
24
+ def DirectDistillLoss(pipe: BasePipeline, **inputs):
25
+ pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
26
+ pipe.scheduler.training = True
27
+ models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
28
+ for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
29
+ timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
30
+ noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
31
+ inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
32
+ loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
33
+ return loss
34
+
35
+
36
+ class TrajectoryImitationLoss(torch.nn.Module):
37
+ def __init__(self):
38
+ super().__init__()
39
+ self.initialized = False
40
+
41
+ def initialize(self, device):
42
+ import lpips # TODO: remove it
43
+ self.loss_fn = lpips.LPIPS(net='alex').to(device)
44
+ self.initialized = True
45
+
46
+ def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
47
+ trajectory = [inputs_shared["latents"].clone()]
48
+
49
+ pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
50
+ models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
51
+ for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
52
+ timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
53
+ noise_pred = pipe.cfg_guided_model_fn(
54
+ pipe.model_fn, cfg_scale,
55
+ inputs_shared, inputs_posi, inputs_nega,
56
+ **models, timestep=timestep, progress_id=progress_id
57
+ )
58
+ inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
59
+
60
+ trajectory.append(inputs_shared["latents"].clone())
61
+ return pipe.scheduler.timesteps, trajectory
62
+
63
+ def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
64
+ loss = 0
65
+ pipe.scheduler.set_timesteps(num_inference_steps, training=True)
66
+ models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
67
+ for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
68
+ timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
69
+
70
+ progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
71
+ inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
72
+
73
+ noise_pred = pipe.cfg_guided_model_fn(
74
+ pipe.model_fn, cfg_scale,
75
+ inputs_shared, inputs_posi, inputs_nega,
76
+ **models, timestep=timestep, progress_id=progress_id
77
+ )
78
+
79
+ sigma = pipe.scheduler.sigmas[progress_id]
80
+ sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
81
+ if progress_id + 1 >= len(pipe.scheduler.timesteps):
82
+ latents_ = trajectory_teacher[-1]
83
+ else:
84
+ progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
85
+ latents_ = trajectory_teacher[progress_id_teacher]
86
+
87
+ target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
88
+ loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
89
+ return loss
90
+
91
+ def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
92
+ inputs_shared["latents"] = trajectory_teacher[0]
93
+ pipe.scheduler.set_timesteps(num_inference_steps)
94
+ models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
95
+ for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
96
+ timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
97
+ noise_pred = pipe.cfg_guided_model_fn(
98
+ pipe.model_fn, cfg_scale,
99
+ inputs_shared, inputs_posi, inputs_nega,
100
+ **models, timestep=timestep, progress_id=progress_id
101
+ )
102
+ inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
103
+
104
+ image_pred = pipe.vae_decoder(inputs_shared["latents"])
105
+ image_real = pipe.vae_decoder(trajectory_teacher[-1])
106
+ loss = self.loss_fn(image_pred.float(), image_real.float())
107
+ return loss
108
+
109
+ def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
110
+ if not self.initialized:
111
+ self.initialize(pipe.device)
112
+ with torch.no_grad():
113
+ pipe.scheduler.set_timesteps(8)
114
+ timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
115
+ timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
116
+ loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
117
+ loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
118
+ loss = loss_1 + loss_2
119
+ return loss
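
FlowMatchSFTLoss draws its timestep uniformly from a fractional sub-range of the schedule; a small worked example of the index math, assuming a 1000-step schedule and boundaries of 0.1 and 0.9:

    import torch

    num_timesteps = 1000
    max_id = int(0.9 * num_timesteps)   # max_timestep_boundary = 0.9 -> 900
    min_id = int(0.1 * num_timesteps)   # min_timestep_boundary = 0.1 -> 100
    timestep_id = torch.randint(min_id, max_id, (1,))
    # timestep_id falls in [100, 900), i.e. the middle 80% of the schedule,
    # which is how the boundaries restrict the noise levels seen during training.
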
diffsynth/diffusion/parsers.py ADDED
@@ -0,0 +1,70 @@
1
+ import argparse
2
+
3
+
4
+ def add_dataset_base_config(parser: argparse.ArgumentParser):
5
+ parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.")
6
+ parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.")
7
+ parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
8
+ parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
9
+ parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
10
+ return parser
11
+
12
+ def add_image_size_config(parser: argparse.ArgumentParser):
13
+ parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
14
+ parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
15
+ parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
16
+ return parser
17
+
18
+ def add_video_size_config(parser: argparse.ArgumentParser):
19
+ parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.")
20
+ parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.")
21
+ parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution.")
22
+ parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
23
+ return parser
24
+
25
+ def add_model_config(parser: argparse.ArgumentParser):
26
+ parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
27
+ parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
28
+ parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
29
+ parser.add_argument("--fp8_models", default=None, help="Models with FP8 precision, comma-separated.")
30
+ parser.add_argument("--offload_models", default=None, help="Models with offload, comma-separated. Only used in splited training.")
31
+ return parser
32
+
33
+ def add_training_config(parser: argparse.ArgumentParser):
34
+ parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
35
+ parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
36
+ parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
37
+ parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
38
+ parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
39
+ parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
40
+ return parser
41
+
42
+ def add_output_config(parser: argparse.ArgumentParser):
43
+ parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
44
+ parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
45
+ parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
46
+ return parser
47
+
48
+ def add_lora_config(parser: argparse.ArgumentParser):
49
+ parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
50
+ parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
51
+ parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
52
+ parser.add_argument("--lora_checkpoint", type=str, default=None, help="Path to the LoRA checkpoint. If provided, LoRA will be loaded from this checkpoint.")
53
+ parser.add_argument("--preset_lora_path", type=str, default=None, help="Path to the preset LoRA checkpoint. If provided, this LoRA will be fused to the base model.")
54
+ parser.add_argument("--preset_lora_model", type=str, default=None, help="Which model the preset LoRA is fused to.")
55
+ return parser
56
+
57
+ def add_gradient_config(parser: argparse.ArgumentParser):
58
+ parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
59
+ parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
60
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
61
+ return parser
62
+
63
+ def add_general_config(parser: argparse.ArgumentParser):
64
+ parser = add_dataset_base_config(parser)
65
+ parser = add_model_config(parser)
66
+ parser = add_training_config(parser)
67
+ parser = add_output_config(parser)
68
+ parser = add_lora_config(parser)
69
+ parser = add_gradient_config(parser)
70
+ return parser
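
A minimal sketch of composing these helpers into a training CLI (argument values below are illustrative only; the import path is assumed from the file location above):

    import argparse
    from diffsynth.diffusion.parsers import add_general_config, add_video_size_config  # assumed import path

    parser = argparse.ArgumentParser()
    parser = add_general_config(parser)
    parser = add_video_size_config(parser)
    args = parser.parse_args([
        "--dataset_base_path", "./data",
        "--learning_rate", "1e-5",
        "--lora_base_model", "dit",
    ])
    print(args.num_frames, args.lora_rank)  # defaults: 81, 32
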
diffsynth/diffusion/runner.py ADDED
@@ -0,0 +1,129 @@
1
+ import os, torch
2
+ import wandb  # newly added
3
+ from tqdm import tqdm
4
+ from accelerate import Accelerator
5
+ from .training_module import DiffusionTrainingModule
6
+ from .logger import ModelLogger
7
+
8
+
9
+ def launch_training_task(
10
+ accelerator: Accelerator,
11
+ dataset: torch.utils.data.Dataset,
12
+ model: DiffusionTrainingModule,
13
+ model_logger: ModelLogger,
14
+ learning_rate: float = 1e-5,
15
+ weight_decay: float = 1e-2,
16
+ num_workers: int = 1,
17
+ save_steps: int = None,
18
+ num_epochs: int = 1,
19
+ args = None,
20
+ ):
21
+ if args is not None:
22
+ learning_rate = args.learning_rate
23
+ weight_decay = args.weight_decay
24
+ num_workers = args.dataset_num_workers
25
+ save_steps = args.save_steps
26
+ num_epochs = args.num_epochs
27
+
28
+ optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
29
+ scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
30
+ dataloader = torch.utils.data.DataLoader(
31
+ dataset,
32
+ shuffle=True,
33
+ collate_fn=lambda x: x[0],
34
+ num_workers=num_workers,
35
+ )
36
+
37
+ model, optimizer, dataloader, scheduler = accelerator.prepare(
38
+ model, optimizer, dataloader, scheduler
39
+ )
40
+
41
+ global_step = 0  # global step counter for wandb logging
42
+
43
+ for epoch_id in range(num_epochs):
44
+ # Show the tqdm bar only on the local main process to avoid duplicated progress bars in multi-GPU runs
45
+ pbar = tqdm(
46
+ dataloader,
47
+ disable=not accelerator.is_local_main_process,
48
+ desc=f"Epoch {epoch_id}",
49
+ )
50
+ for data in pbar:
51
+ with accelerator.accumulate(model):
52
+ optimizer.zero_grad()
53
+ if dataset.load_from_cache:
54
+ loss = model({}, inputs=data)
55
+ else:
56
+ loss = model(data)
57
+ accelerator.backward(loss)
58
+ optimizer.step()
59
+ model_logger.on_step_end(accelerator, model, save_steps)
60
+ scheduler.step()
61
+
62
+ global_step += 1
63
+
64
+ # ============= wandb logging (main process only) =============
65
+ if (
66
+ args is not None
67
+ and hasattr(args, "wandb_mode")
68
+ and args.wandb_mode != "disabled"
69
+ and accelerator.is_main_process
70
+ ):
71
+ log_every = getattr(args, "wandb_log_every", 10)
72
+ if global_step % log_every == 0:
73
+ # The loss from the current process is sufficient here
74
+ loss_value = loss.detach().float().item()
75
+ try:
76
+ lr = scheduler.get_last_lr()[0]
77
+ except Exception:
78
+ lr = optimizer.param_groups[0]["lr"]
79
+
80
+ wandb.log(
81
+ {
82
+ "train/loss": loss_value,
83
+ "train/lr": lr,
84
+ "train/epoch": epoch_id,
85
+ "train/step": global_step,
86
+ }
87
+ )
88
+ # =======================================================
89
+
90
+ if save_steps is None:
91
+ model_logger.on_epoch_end(accelerator, model, epoch_id)
92
+ model_logger.on_training_end(accelerator, model, save_steps)
93
+
94
+
95
+ def launch_data_process_task(
96
+ accelerator: Accelerator,
97
+ dataset: torch.utils.data.Dataset,
98
+ model: DiffusionTrainingModule,
99
+ model_logger: ModelLogger,
100
+ num_workers: int = 8,
101
+ args = None,
102
+ ):
103
+ if args is not None:
104
+ num_workers = args.dataset_num_workers
105
+
106
+ dataloader = torch.utils.data.DataLoader(
107
+ dataset,
108
+ shuffle=False,
109
+ collate_fn=lambda x: x[0],
110
+ num_workers=num_workers,
111
+ )
112
+ model, dataloader = accelerator.prepare(model, dataloader)
113
+
114
+ for data_id, data in enumerate(tqdm(
115
+ dataloader,
116
+ disable=not accelerator.is_local_main_process,
117
+ desc="Data process",
118
+ )):
119
+ with accelerator.accumulate(model):
120
+ with torch.no_grad():
121
+ folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
122
+ os.makedirs(folder, exist_ok=True)
123
+ save_path = os.path.join(
124
+ model_logger.output_path,
125
+ str(accelerator.process_index),
126
+ f"{data_id}.pth",
127
+ )
128
+ data = model(data)
129
+ torch.save(data, save_path)
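
The wandb branch above is gated purely on attributes of args; a sketch of the extra arguments it expects (the names come from the getattr calls above, and wandb.init is assumed to happen elsewhere in the training script):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--wandb_mode", type=str, default="disabled", help="Set to anything other than 'disabled' to enable wandb logging.")
    parser.add_argument("--wandb_log_every", type=int, default=10, help="Log train/loss, train/lr, train/epoch and train/step every N optimizer steps.")
    args = parser.parse_args([])
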
diffsynth/diffusion/training_module.py ADDED
@@ -0,0 +1,212 @@
1
+ import torch, json
2
+ from ..core import ModelConfig, load_state_dict
3
+ from ..utils.controlnet import ControlNetInput
4
+ from peft import LoraConfig, inject_adapter_in_model
5
+
6
+
7
+ class DiffusionTrainingModule(torch.nn.Module):
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+
12
+ def to(self, *args, **kwargs):
13
+ for name, model in self.named_children():
14
+ model.to(*args, **kwargs)
15
+ return self
16
+
17
+
18
+ def trainable_modules(self):
19
+ trainable_modules = filter(lambda p: p.requires_grad, self.parameters())
20
+ return trainable_modules
21
+
22
+
23
+ def trainable_param_names(self):
24
+ trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.named_parameters()))
25
+ trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
26
+ return trainable_param_names
27
+
28
+
29
+ def add_lora_to_model(self, model, target_modules, lora_rank, lora_alpha=None, upcast_dtype=None):
30
+ if lora_alpha is None:
31
+ lora_alpha = lora_rank
32
+ if isinstance(target_modules, list) and len(target_modules) == 1:
33
+ target_modules = target_modules[0]
34
+ lora_config = LoraConfig(r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules)
35
+ model = inject_adapter_in_model(lora_config, model)
36
+ if upcast_dtype is not None:
37
+ for param in model.parameters():
38
+ if param.requires_grad:
39
+ param.data = param.to(upcast_dtype)
40
+ return model
41
+
42
+
43
+ def mapping_lora_state_dict(self, state_dict):
44
+ new_state_dict = {}
45
+ for key, value in state_dict.items():
46
+ if "lora_A.weight" in key or "lora_B.weight" in key:
47
+ new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
48
+ new_state_dict[new_key] = value
49
+ elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
50
+ new_state_dict[key] = value
51
+ return new_state_dict
52
+
53
+
54
+ def export_trainable_state_dict(self, state_dict, remove_prefix=None):
55
+ trainable_param_names = self.trainable_param_names()
56
+ state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
57
+ if remove_prefix is not None:
58
+ state_dict_ = {}
59
+ for name, param in state_dict.items():
60
+ if name.startswith(remove_prefix):
61
+ name = name[len(remove_prefix):]
62
+ state_dict_[name] = param
63
+ state_dict = state_dict_
64
+ return state_dict
65
+
66
+
67
+ def transfer_data_to_device(self, data, device, torch_float_dtype=None):
68
+ if data is None:
69
+ return data
70
+ elif isinstance(data, torch.Tensor):
71
+ data = data.to(device)
72
+ if torch_float_dtype is not None and data.dtype in [torch.float, torch.float16, torch.bfloat16]:
73
+ data = data.to(torch_float_dtype)
74
+ return data
75
+ elif isinstance(data, tuple):
76
+ data = tuple(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
77
+ return data
78
+ elif isinstance(data, list):
79
+ data = list(self.transfer_data_to_device(x, device, torch_float_dtype) for x in data)
80
+ return data
81
+ elif isinstance(data, dict):
82
+ data = {i: self.transfer_data_to_device(data[i], device, torch_float_dtype) for i in data}
83
+ return data
84
+ else:
85
+ return data
86
+
87
+ def parse_vram_config(self, fp8=False, offload=False, device="cpu"):
88
+ if fp8:
89
+ return {
90
+ "offload_dtype": torch.float8_e4m3fn,
91
+ "offload_device": device,
92
+ "onload_dtype": torch.float8_e4m3fn,
93
+ "onload_device": device,
94
+ "preparing_dtype": torch.float8_e4m3fn,
95
+ "preparing_device": device,
96
+ "computation_dtype": torch.bfloat16,
97
+ "computation_device": device,
98
+ }
99
+ elif offload:
100
+ return {
101
+ "offload_dtype": "disk",
102
+ "offload_device": "disk",
103
+ "onload_dtype": "disk",
104
+ "onload_device": "disk",
105
+ "preparing_dtype": torch.bfloat16,
106
+ "preparing_device": device,
107
+ "computation_dtype": torch.bfloat16,
108
+ "computation_device": device,
109
+ "clear_parameters": True,
110
+ }
111
+ else:
112
+ return {}
113
+
114
+ def parse_model_configs(self, model_paths, model_id_with_origin_paths, fp8_models=None, offload_models=None, device="cpu"):
115
+ fp8_models = [] if fp8_models is None else fp8_models.split(",")
116
+ offload_models = [] if offload_models is None else offload_models.split(",")
117
+ model_configs = []
118
+ if model_paths is not None:
119
+ model_paths = json.loads(model_paths)
120
+ for path in model_paths:
121
+ vram_config = self.parse_vram_config(
122
+ fp8=path in fp8_models,
123
+ offload=path in offload_models,
124
+ device=device
125
+ )
126
+ model_configs.append(ModelConfig(path=path, **vram_config))
127
+ if model_id_with_origin_paths is not None:
128
+ model_id_with_origin_paths = model_id_with_origin_paths.split(",")
129
+ for model_id_with_origin_path in model_id_with_origin_paths:
130
+ model_id, origin_file_pattern = model_id_with_origin_path.split(":")
131
+ vram_config = self.parse_vram_config(
132
+ fp8=model_id_with_origin_path in fp8_models,
133
+ offload=model_id_with_origin_path in offload_models,
134
+ device=device
135
+ )
136
+ model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
137
+ return model_configs
138
+
139
+
140
+ def switch_pipe_to_training_mode(
141
+ self,
142
+ pipe,
143
+ trainable_models=None,
144
+ lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
145
+ preset_lora_path=None, preset_lora_model=None,
146
+ task="sft",
147
+ ):
148
+ # Scheduler
149
+ pipe.scheduler.set_timesteps(1000, training=True)
150
+
151
+ # Freeze untrainable models
152
+ pipe.freeze_except([] if trainable_models is None else trainable_models.split(","))
153
+
154
+ # Preset LoRA
155
+ if preset_lora_path is not None:
156
+ pipe.load_lora(getattr(pipe, preset_lora_model), preset_lora_path)
157
+
158
+ # FP8
159
+ # FP8 relies on a model-specific memory management scheme.
160
+ # It is delegated to the subclass.
161
+
162
+ # Add LoRA to the base models
163
+ if lora_base_model is not None and not task.endswith(":data_process"):
164
+ if (not hasattr(pipe, lora_base_model)) or getattr(pipe, lora_base_model) is None:
165
+ print(f"No {lora_base_model} models in the pipeline. We cannot patch LoRA on the model. If this occurs during the data processing stage, it is normal.")
166
+ return
167
+ model = self.add_lora_to_model(
168
+ getattr(pipe, lora_base_model),
169
+ target_modules=lora_target_modules.split(","),
170
+ lora_rank=lora_rank,
171
+ upcast_dtype=pipe.torch_dtype,
172
+ )
173
+ if lora_checkpoint is not None:
174
+ state_dict = load_state_dict(lora_checkpoint)
175
+ state_dict = self.mapping_lora_state_dict(state_dict)
176
+ load_result = model.load_state_dict(state_dict, strict=False)
177
+ print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
178
+ if len(load_result[1]) > 0:
179
+ print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
180
+ setattr(pipe, lora_base_model, model)
181
+
182
+
183
+ def split_pipeline_units(self, task, pipe, trainable_models=None, lora_base_model=None):
184
+ models_require_backward = []
185
+ if trainable_models is not None:
186
+ models_require_backward += trainable_models.split(",")
187
+ if lora_base_model is not None:
188
+ models_require_backward += [lora_base_model]
189
+ if task.endswith(":data_process"):
190
+ _, pipe.units = pipe.split_pipeline_units(models_require_backward)
191
+ elif task.endswith(":train"):
192
+ pipe.units, _ = pipe.split_pipeline_units(models_require_backward)
193
+ return pipe
194
+
195
+ def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
196
+ controlnet_keys_map = (
197
+ ("blockwise_controlnet_", "blockwise_controlnet_inputs",),
198
+ ("controlnet_", "controlnet_inputs"),
199
+ )
200
+ controlnet_inputs = {}
201
+ for extra_input in extra_inputs:
202
+ for prefix, name in controlnet_keys_map:
203
+ if extra_input.startswith(prefix):
204
+ if name not in controlnet_inputs:
205
+ controlnet_inputs[name] = {}
206
+ controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
207
+ break
208
+ else:
209
+ inputs_shared[extra_input] = data[extra_input]
210
+ for name, params in controlnet_inputs.items():
211
+ inputs_shared[name] = [ControlNetInput(**params)]
212
+ return inputs_shared
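
parse_model_configs above accepts two string formats; a small illustration of the expected inputs (the paths are placeholders, and the model ID mirrors the help text in parsers.py):

    # A JSON list of local checkpoint paths:
    model_paths = '["models/dit.safetensors", "models/vae.safetensors"]'
    # Comma-separated model_id:origin_file_pattern pairs:
    model_id_with_origin_paths = "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"
    # Inside a DiffusionTrainingModule subclass this yields a list of ModelConfig objects,
    # optionally tagged with FP8 or disk-offload VRAM settings:
    #     configs = self.parse_model_configs(model_paths, model_id_with_origin_paths,
    #                                        fp8_models=None, offload_models=None, device="cpu")
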
diffsynth/models/comp_attn_model.py ADDED
@@ -0,0 +1,592 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Sequence
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from ..diffusion.base_pipeline import PipelineUnit
9
+
10
+
11
+ @dataclass
12
+ class CompAttnConfig:
13
+ subjects: Sequence[str]
14
+ bboxes: Optional[Sequence] = None
15
+ enable_sci: bool = True
16
+ enable_lam: bool = True
17
+ temperature: float = 0.2
18
+ apply_to_negative: bool = False
19
+ interpolate: bool = False
20
+ state_texts: Optional[Sequence[Sequence[str]]] = None
21
+ state_weights: Optional[Sequence] = None
22
+ state_scale: float = 1.0
23
+ state_template: str = "{subject} is {state}"
24
+
25
+
26
+ def find_subsequence_indices(prompt_ids: torch.Tensor, subject_ids: torch.Tensor, valid_len: int) -> list[int]:
27
+ if subject_ids.numel() == 0 or valid_len <= 0:
28
+ return []
29
+ prompt_slice = prompt_ids[:valid_len].tolist()
30
+ subject_list = subject_ids.tolist()
31
+ span = len(subject_list)
32
+ if span > valid_len:
33
+ return []
34
+ for start in range(valid_len - span + 1):
35
+ if prompt_slice[start:start + span] == subject_list:
36
+ return list(range(start, start + span))
37
+ return []
38
+
39
+
40
+ def build_subject_token_mask(indices_list: list[list[int]], seq_len: int) -> torch.Tensor:
41
+ mask = torch.zeros((len(indices_list), seq_len), dtype=torch.bool)
42
+ for i, indices in enumerate(indices_list):
43
+ if not indices:
44
+ continue
45
+ mask[i, torch.tensor(indices, dtype=torch.long)] = True
46
+ return mask
47
+
48
+
49
+ def compute_saliency(prompt_vecs: torch.Tensor, anchor_vecs: torch.Tensor, tau: float) -> torch.Tensor:
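+ # Saliency is a softmax over anchors: s_i = exp(cos(p_i, a_i) / tau) / sum_j exp(cos(p_i, a_j) / tau),
+ # i.e. how exclusively each subject's pooled prompt embedding matches its own anchor embedding.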
50
+ prompt_norm = prompt_vecs / (prompt_vecs.norm(dim=-1, keepdim=True) + 1e-8)
51
+ anchor_norm = anchor_vecs / (anchor_vecs.norm(dim=-1, keepdim=True) + 1e-8)
52
+ cosine = torch.matmul(prompt_norm, anchor_norm.transpose(0, 1))
53
+ scores = torch.exp(cosine / tau)
54
+ diag = scores.diagonal()
55
+ denom = scores.sum(dim=1).clamp(min=1e-8)
56
+ return diag / denom
57
+
58
+
59
+ def compute_delta(anchor_vecs: torch.Tensor) -> torch.Tensor:
60
+ total = anchor_vecs.sum(dim=0, keepdim=True)
61
+ return anchor_vecs * anchor_vecs.shape[0] - total
62
+
63
+
64
+ _sci_call_count = [0]  # stored in a list so it can be mutated inside the function
65
+
66
+ def apply_sci(context: torch.Tensor, state: dict, timestep: torch.Tensor) -> torch.Tensor:
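+ # SCI adds each subject's correction delta onto that subject's own prompt tokens, scaled by
+ # omega * (1 - saliency) with omega = 1 - t / timestep_scale, so weakly salient subjects
+ # receive a stronger correction and the correction grows as the timestep decreases.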
67
+ if state is None or not state.get("enable_sci", False):
68
+ return context
69
+ subject_mask = state.get("subject_token_mask")
70
+ delta = state.get("delta")
71
+ saliency = state.get("saliency")
72
+ if subject_mask is None or delta is None or saliency is None:
73
+ return context
74
+ if subject_mask.numel() == 0:
75
+ return context
76
+ t_scale = float(state.get("timestep_scale", 1000.0))
77
+ t_value = float(timestep.reshape(-1)[0].item())
78
+ t_ratio = max(0.0, min(1.0, t_value / t_scale))
79
+ omega = 1.0 - t_ratio
80
+ delta = delta.to(device=context.device, dtype=context.dtype)
81
+ saliency = saliency.to(device=context.device, dtype=context.dtype)
82
+ scale = omega * (1.0 - saliency).unsqueeze(-1)
83
+ delta = delta * scale
84
+ mask = subject_mask.to(device=context.device)
85
+ token_delta = torch.matmul(mask.to(dtype=context.dtype).transpose(0, 1), delta)
86
+ apply_mask = state.get("apply_mask")
87
+ if apply_mask is not None:
88
+ apply_mask = apply_mask.to(device=context.device, dtype=context.dtype).view(-1, 1, 1)
89
+ else:
90
+ apply_mask = 1.0
91
+
92
+ # ========== DEBUG: print SCI information ==========
93
+ _sci_call_count[0] += 1
94
+ if _sci_call_count[0] % 100 == 1:
95
+ print(f"\n{'='*60}")
96
+ print(f"[SCI (Saliency-Controlled Intervention) #{_sci_call_count[0]}]")
97
+ print(f" timestep: {t_value:.2f}, t_ratio: {t_ratio:.4f}, omega: {omega:.4f}")
98
+ print(f" saliency per subject: {saliency.tolist()}")
99
+ print(f" delta shape: {delta.shape}")
100
+ print(f" delta norm per subject: {delta.norm(dim=-1).tolist()}")
101
+ print(f" token_delta shape: {token_delta.shape}")
102
+ print(f" context modification norm: {(token_delta.unsqueeze(0) * apply_mask).norm().item():.6f}")
103
+ print(f"{'='*60}\n")
104
+
105
+ return context + token_delta.unsqueeze(0) * apply_mask
106
+
107
+
108
+ def interpolate_bboxes(bboxes: torch.Tensor, target_frames: int) -> torch.Tensor:
109
+ if bboxes.shape[2] == target_frames:
110
+ return bboxes
111
+ b, m, f, _ = bboxes.shape
112
+ coords = bboxes.reshape(b * m, f, 4).transpose(1, 2)
113
+ coords = F.interpolate(coords, size=target_frames, mode="linear", align_corners=True)
114
+ coords = coords.transpose(1, 2).reshape(b, m, target_frames, 4)
115
+ return coords
116
+
117
+
118
+ def build_layout_mask_from_bboxes(
119
+ bboxes: torch.Tensor,
120
+ grid_size: tuple[int, int, int],
121
+ image_size: tuple[int, int],
122
+ device: torch.device,
123
+ dtype: torch.dtype,
124
+ ) -> torch.Tensor:
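+ # Rasterizes per-subject pixel-space boxes (B, M, F_layout, 4) onto the latent grid
+ # (f_grid, h_grid, w_grid) and returns a flattened (B, M, f*h*w) binary layout mask.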
125
+ if bboxes is None:
126
+ return None
127
+ bboxes = bboxes.to(device=device, dtype=dtype)
128
+ b, m, f_layout, _ = bboxes.shape
129
+ f_grid, h_grid, w_grid = grid_size
130
+ height, width = image_size
131
+ layout = torch.zeros((b, m, f_grid, h_grid, w_grid), device=device, dtype=dtype)
132
+ for bi in range(b):
133
+ for mi in range(m):
134
+ for ti in range(f_layout):
135
+ pt = int(ti * f_grid / max(1, f_layout))
136
+ pt = max(0, min(f_grid - 1, pt))
137
+ x0, y0, x1, y1 = bboxes[bi, mi, ti]
138
+ x0 = float(x0)
139
+ y0 = float(y0)
140
+ x1 = float(x1)
141
+ y1 = float(y1)
142
+ if x1 <= x0 or y1 <= y0:
143
+ continue
144
+ px0 = int(math.floor(x0 / max(1.0, width) * w_grid))
145
+ px1 = int(math.ceil(x1 / max(1.0, width) * w_grid))
146
+ py0 = int(math.floor(y0 / max(1.0, height) * h_grid))
147
+ py1 = int(math.ceil(y1 / max(1.0, height) * h_grid))
148
+ px0 = max(0, min(w_grid, px0))
149
+ px1 = max(0, min(w_grid, px1))
150
+ py0 = max(0, min(h_grid, py0))
151
+ py1 = max(0, min(h_grid, py1))
152
+ if px1 <= px0 or py1 <= py0:
153
+ continue
154
+ layout[bi, mi, pt, py0:py1, px0:px1] = 1.0
155
+ return layout.flatten(2)
156
+
157
+
158
+ _lam_attention_call_count = [0]  # stored in a list so it can be mutated inside the function
159
+
160
+ def lam_attention(
161
+ q: torch.Tensor,
162
+ k: torch.Tensor,
163
+ v: torch.Tensor,
164
+ num_heads: int,
165
+ state: dict,
166
+ ) -> Optional[torch.Tensor]:
167
+ subject_mask = state.get("subject_token_mask_lam")
168
+ if subject_mask is None:
169
+ subject_mask = state.get("subject_token_mask")
170
+ layout_mask = state.get("layout_mask")
171
+ state_token_mask = state.get("state_token_mask")
172
+ state_token_weights = state.get("state_token_weights")
173
+ state_scale = float(state.get("state_scale", 1.0))
174
+ grid_shape = state.get("grid_shape")
175
+ enable_lam = bool(state.get("enable_lam", False))
176
+ enable_state = state_token_mask is not None and state_token_weights is not None and grid_shape is not None
177
+ if not enable_lam and not enable_state:
178
+ return None
179
+ b, q_len, dim = q.shape
180
+ _, k_len, _ = k.shape
181
+ if enable_lam:
182
+ if subject_mask is None or layout_mask is None:
183
+ return None
184
+ if subject_mask.numel() == 0 or layout_mask.numel() == 0:
185
+ return None
186
+ if layout_mask.shape[-1] != q_len:
187
+ return None
188
+ if subject_mask.shape[-1] != k_len:
189
+ return None
190
+ if enable_state:
191
+ if state_token_mask.shape[-1] != k_len:
192
+ return None
193
+ head_dim = dim // num_heads
194
+ qh = q.view(b, q_len, num_heads, head_dim).transpose(1, 2)
195
+ kh = k.view(b, k_len, num_heads, head_dim).transpose(1, 2)
196
+ vh = v.view(b, k_len, num_heads, head_dim).transpose(1, 2)
197
+ attn_scores = torch.matmul(qh.float(), kh.float().transpose(-2, -1)) / math.sqrt(head_dim)
198
+
199
+ # ========== DEBUG: print attention map information ==========
200
+ _lam_attention_call_count[0] += 1
201
+ call_id = _lam_attention_call_count[0]
202
+ # Print only once every 100 calls to avoid flooding the output
203
+ if call_id % 100 == 1:
204
+ print(f"\n{'='*60}")
205
+ print(f"[LAM Attention #{call_id}]")
206
+ print(f" Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
207
+ print(f" num_heads: {num_heads}, head_dim: {head_dim}")
208
+ print(f" attn_scores shape: {attn_scores.shape}")
209
+ print(f" attn_scores stats: min={attn_scores.min().item():.4f}, max={attn_scores.max().item():.4f}, mean={attn_scores.mean().item():.4f}")
210
+ if enable_lam and layout_mask is not None:
211
+ print(f" layout_mask shape: {layout_mask.shape}")
212
+ print(f" layout_mask sum per subject: {layout_mask.sum(dim=-1)}")
213
+ if subject_mask is not None:
214
+ print(f" subject_token_mask shape: {subject_mask.shape}")
215
+ print(f" subject_token_mask active tokens per subject: {subject_mask.sum(dim=-1).tolist()}")
216
+ if grid_shape is not None:
217
+ print(f" grid_shape (f, h, w): {grid_shape}")
218
+ print(f"{'='*60}")
219
+ bias = torch.zeros_like(attn_scores)
220
+ if enable_lam:
221
+ attn_max = attn_scores.max(dim=-1, keepdim=True).values
222
+ attn_min = attn_scores.min(dim=-1, keepdim=True).values
223
+ g_plus = attn_max - attn_scores
224
+ g_minus = attn_min - attn_scores
225
+ subject_mask = subject_mask.to(device=attn_scores.device)
226
+ layout_mask = layout_mask.to(device=attn_scores.device, dtype=attn_scores.dtype)
227
+ apply_mask = state.get("apply_mask")
228
+ if apply_mask is not None:
229
+ layout_mask = layout_mask * apply_mask.to(device=layout_mask.device, dtype=layout_mask.dtype).view(-1, 1, 1)
230
+ subject_any = subject_mask.any(dim=0)
231
+ for k_idx in range(subject_mask.shape[0]):
232
+ mask_k = subject_mask[k_idx]
233
+ if not mask_k.any():
234
+ continue
235
+ mask_other = subject_any & (~mask_k)
236
+ mask_k = mask_k.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
237
+ mask_other = mask_other.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
238
+ g_k = g_plus * mask_k + g_minus * mask_other
239
+ attn_k = attn_scores[..., subject_mask[k_idx]].mean(dim=-1).mean(dim=1)
240
+ adapt_mask = attn_k >= attn_k.mean(dim=-1, keepdim=True)
241
+ layout_k = layout_mask[:, k_idx]
242
+ adapt_f = adapt_mask.to(layout_k.dtype)
243
+ inter = (adapt_f * layout_k).sum(dim=-1)
244
+ union = (adapt_f + layout_k - adapt_f * layout_k).sum(dim=-1)
245
+ iou = inter / union.clamp(min=1e-6)
246
+ strength = (1.0 - iou).view(b, 1, 1, 1)
247
+ bias = bias + g_k * strength * layout_k.view(b, 1, q_len, 1)
248
+ if enable_state:
249
+ f, h, w = grid_shape
250
+ if f * h * w != q_len:
251
+ return None
252
+ state_token_mask = state_token_mask.to(device=attn_scores.device)
253
+ state_indices = torch.nonzero(state_token_mask, as_tuple=False).flatten()
254
+ if state_indices.numel() == 0:
255
+ return None
256
+ weights = state_token_weights.to(device=attn_scores.device, dtype=attn_scores.dtype)
257
+ if weights.shape[1] != f:
258
+ return None
259
+ time_index = torch.arange(q_len, device=attn_scores.device) // (h * w)
260
+ weights_q = weights[:, time_index, :]
261
+ if weights_q.shape[-1] != state_indices.numel():
262
+ return None
263
+ state_bias = torch.zeros((b, 1, q_len, k_len), device=attn_scores.device, dtype=attn_scores.dtype)
264
+ state_bias[:, :, :, state_indices] = weights_q.unsqueeze(1) * state_scale
265
+ bias = bias + state_bias
266
+ attn_probs = torch.softmax(attn_scores + bias, dim=-1).to(vh.dtype)
267
+
268
+ # ========== DEBUG: print attention probs and bias information ==========
269
+ if _lam_attention_call_count[0] % 100 == 1:
270
+ print(f"\n[LAM Attention #{_lam_attention_call_count[0]} - After Bias]")
271
+ print(f" bias shape: {bias.shape}")
272
+ print(f" bias stats: min={bias.min().item():.4f}, max={bias.max().item():.4f}, mean={bias.mean().item():.4f}")
273
+ print(f" bias non-zero ratio: {(bias != 0).float().mean().item():.4f}")
274
+ print(f" attn_probs shape: {attn_probs.shape}")
275
+ print(f" attn_probs stats: min={attn_probs.min().item():.6f}, max={attn_probs.max().item():.6f}")
276
+ # Print the average attention weight over each subject's tokens
277
+ if subject_mask is not None:
278
+ for subj_idx in range(subject_mask.shape[0]):
279
+ mask_k = subject_mask[subj_idx]
280
+ if mask_k.any():
281
+ # Average attention from all queries to this subject's tokens
282
+ subj_attn = attn_probs[:, :, :, mask_k.to(attn_probs.device)].mean()
283
+ print(f" Subject {subj_idx} avg attention weight: {subj_attn.item():.6f}")
284
+ print(f"{'='*60}\n")
285
+
286
+ out = torch.matmul(attn_probs, vh)
287
+ out = out.transpose(1, 2).reshape(b, q_len, dim)
288
+ return out
289
+
290
+
291
+ class CompAttnUnit(PipelineUnit):
292
+ def __init__(self):
293
+ super().__init__(
294
+ seperate_cfg=True,
295
+ input_params_posi={"prompt": "prompt", "context": "context"},
296
+ input_params_nega={"prompt": "negative_prompt", "context": "context"},
297
+ output_params=("comp_attn_state",),
298
+ onload_model_names=("text_encoder",),
299
+ )
300
+
301
+ def _clean_text(self, pipe, text: str) -> str:
302
+ if getattr(pipe.tokenizer, "clean", None):
303
+ return pipe.tokenizer._clean(text)
304
+ return text
305
+
306
+ def _tokenize_subject(self, pipe, text: str) -> torch.Tensor:
307
+ text = self._clean_text(pipe, text)
308
+ tokens = pipe.tokenizer.tokenizer(text, add_special_tokens=False, return_tensors="pt")
309
+ return tokens["input_ids"][0]
310
+
311
+ def _normalize_bboxes(self, bboxes: Sequence) -> torch.Tensor:
312
+ bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
313
+ if bboxes.dim() == 2 and bboxes.shape[-1] == 4:
314
+ bboxes = bboxes.unsqueeze(0).unsqueeze(0)
315
+ elif bboxes.dim() == 3 and bboxes.shape[-1] == 4:
316
+ bboxes = bboxes.unsqueeze(0)
317
+ elif bboxes.dim() != 4 or bboxes.shape[-1] != 4:
318
+ raise ValueError(f"comp_attn_bboxes must be (..., 4), got shape {tuple(bboxes.shape)}")
319
+ return bboxes
320
+
321
+ def process(self, pipe, prompt, context) -> dict:
322
+ config: Optional[CompAttnConfig] = getattr(pipe, "_comp_attn_config", None)
323
+ if context is None or prompt is None or config is None:
324
+ return {}
325
+ if not config.subjects:
326
+ return {}
327
+ negative_prompt = getattr(pipe, "_comp_attn_last_negative_prompt", None)
328
+ if (not config.apply_to_negative) and negative_prompt and prompt == negative_prompt:
329
+ return {}
330
+ pipe.load_models_to_device(self.onload_model_names)
331
+ ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)
332
+ prompt_ids = ids[0]
333
+ valid_len = int(mask[0].sum().item())
334
+ indices_list = []
335
+ valid_subjects = []
336
+ for idx, subject in enumerate(config.subjects):
337
+ subject_ids = self._tokenize_subject(pipe, subject)
338
+ indices = find_subsequence_indices(prompt_ids, subject_ids, valid_len)
339
+ if not indices:
340
+ print(f"Comp-Attn: subject tokens not found in prompt: {subject}")
341
+ continue
342
+ indices_list.append(indices)
343
+ valid_subjects.append(idx)
344
+ if not indices_list:
345
+ return {}
346
+ subject_token_mask = build_subject_token_mask(indices_list, prompt_ids.shape[0]).to(device=context.device)
347
+ mask_float = subject_token_mask.to(dtype=context.dtype)
348
+ denom = mask_float.sum(dim=1, keepdim=True).clamp(min=1)
349
+ prompt_vecs = (mask_float @ context[0]) / denom
350
+ anchor_vecs = []
351
+ for idx in valid_subjects:
352
+ subject = config.subjects[idx]
353
+ sub_ids, sub_mask = pipe.tokenizer(subject, return_mask=True, add_special_tokens=True)
354
+ sub_ids = sub_ids.to(pipe.device)
355
+ sub_mask = sub_mask.to(pipe.device)
356
+ emb = pipe.text_encoder(sub_ids, sub_mask)
357
+ pooled = (emb * sub_mask.unsqueeze(-1)).sum(dim=1) / sub_mask.sum(dim=1, keepdim=True).clamp(min=1)
358
+ anchor_vecs.append(pooled)
359
+ anchor_vecs = torch.cat(anchor_vecs, dim=0)
360
+ saliency = compute_saliency(prompt_vecs.float(), anchor_vecs.float(), float(config.temperature)).to(prompt_vecs.dtype)
361
+ delta = compute_delta(anchor_vecs.to(prompt_vecs.dtype))
362
+ bboxes = None
363
+ state_vectors = None
364
+ state_weights = None
365
+ state_len = 0
366
+ if config.bboxes is not None:
367
+ bboxes = self._normalize_bboxes(config.bboxes)
368
+ if bboxes.shape[1] >= len(config.subjects):
369
+ bboxes = bboxes[:, valid_subjects]
370
+ if bboxes.shape[1] != len(valid_subjects):
371
+ print("Comp-Attn: bboxes subject count mismatch, disable LAM")
372
+ bboxes = None
373
+ if bboxes is not None and config.interpolate and getattr(pipe, "_comp_attn_num_frames", None) is not None:
374
+ bboxes = interpolate_bboxes(bboxes, int(pipe._comp_attn_num_frames))
375
+ if config.state_texts is not None and config.state_weights is not None:
376
+ state_texts = config.state_texts
377
+ if len(valid_subjects) != len(config.subjects):
378
+ subject_names = [config.subjects[i] for i in valid_subjects]
379
+ state_texts = [state_texts[i] for i in valid_subjects]
380
+ else:
381
+ subject_names = list(config.subjects)
382
+ if len(state_texts) != len(subject_names):
383
+ raise ValueError("state_texts must align with subjects")
384
+ state_count = len(state_texts[0])
385
+ for row in state_texts:
386
+ if len(row) != state_count:
387
+ raise ValueError("state_texts must have the same number of states per subject")
388
+ phrases = []
389
+ for subject, states in zip(subject_names, state_texts):
390
+ for state in states:
391
+ phrases.append(config.state_template.format(subject=subject, state=state))
392
+ ids, mask = pipe.tokenizer(phrases, return_mask=True, add_special_tokens=True)
393
+ ids = ids.to(pipe.device)
394
+ mask = mask.to(pipe.device)
395
+ emb = pipe.text_encoder(ids, mask)
396
+ pooled = (emb * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
397
+ state_vectors = pooled.to(dtype=prompt_vecs.dtype, device="cpu")
398
+ state_len = state_vectors.shape[0]
399
+ weights = torch.as_tensor(config.state_weights, dtype=torch.float32)
400
+ if weights.dim() == 3:
401
+ weights = weights.unsqueeze(0)
402
+ if weights.dim() != 4:
403
+ raise ValueError("state_weights must be (M,F,S) or (B,M,F,S)")
404
+ if weights.shape[1] >= len(config.subjects) and len(valid_subjects) != len(config.subjects):
405
+ weights = weights[:, valid_subjects]
406
+ if weights.shape[1] != len(subject_names) or weights.shape[3] != state_count:
407
+ raise ValueError("state_weights shape does not match state_texts")
408
+ weights = weights[:, :len(subject_names)]
409
+ weights = weights.permute(0, 2, 1, 3).contiguous()
410
+ weights = weights.reshape(weights.shape[0], weights.shape[1], weights.shape[2] * weights.shape[3])
411
+ state_weights = weights.to(device="cpu")
412
+ state = {
413
+ "enable_sci": bool(config.enable_sci),
414
+ "enable_lam": bool(config.enable_lam) and bboxes is not None,
415
+ "subject_token_mask": subject_token_mask,
416
+ "saliency": saliency,
417
+ "delta": delta,
418
+ "layout_bboxes": bboxes,
419
+ "state_vectors": state_vectors,
420
+ "state_weights": state_weights,
421
+ "state_scale": float(config.state_scale),
422
+ "prompt_len": int(prompt_ids.shape[0]),
423
+ "state_len": int(state_len),
424
+ "timestep_scale": 1000.0,
425
+ "apply_to_negative": bool(config.apply_to_negative),
426
+ }
427
+ if negative_prompt and prompt == negative_prompt:
428
+ pipe._comp_attn_state_neg = state
429
+ else:
430
+ pipe._comp_attn_state_pos = state
431
+ return {"comp_attn_state": state}
432
+
433
+
434
+ class CompAttnMergeUnit(PipelineUnit):
435
+ def __init__(self):
436
+ super().__init__(input_params=("cfg_merge",), output_params=("comp_attn_state",))
437
+
438
+ def process(self, pipe, cfg_merge) -> dict:
439
+ if not cfg_merge:
440
+ return {}
441
+ state_pos = getattr(pipe, "_comp_attn_state_pos", None)
442
+ state_neg = getattr(pipe, "_comp_attn_state_neg", None)
443
+ merged = state_pos or state_neg
444
+ if merged is None:
445
+ return {}
446
+ merged = dict(merged)
447
+ apply_to_negative = bool(merged.get("apply_to_negative", False))
448
+ merged["apply_mask"] = torch.tensor([1.0, 1.0 if apply_to_negative else 0.0])
449
+ return {"comp_attn_state": merged}
450
+
451
+
452
+ def patch_cross_attention(pipe) -> None:
453
+ for block in pipe.dit.blocks:
454
+ cross_attn = block.cross_attn
455
+ if getattr(cross_attn, "_comp_attn_patched", False):
456
+ continue
457
+ orig_forward = cross_attn.forward
458
+
459
+ def forward_with_lam(self, x, y, _orig=orig_forward, _pipe=pipe):
460
+ state = getattr(_pipe, "_comp_attn_runtime_state", None)
461
+ enable_lam = bool(state.get("enable_lam", False)) if state else False
462
+ enable_state = bool(state.get("state_token_weights") is not None) if state else False
463
+ if state is None or (not enable_lam and not enable_state):
464
+ return _orig(x, y)
465
+ if self.has_image_input:
466
+ img = y[:, :257]
467
+ ctx = y[:, 257:]
468
+ else:
469
+ ctx = y
470
+ q = self.norm_q(self.q(x))
471
+ k = self.norm_k(self.k(ctx))
472
+ v = self.v(ctx)
473
+ lam_out = lam_attention(q, k, v, self.num_heads, state)
474
+ if lam_out is None:
475
+ out = self.attn(q, k, v)
476
+ else:
477
+ out = lam_out
478
+ if self.has_image_input:
479
+ k_img = self.norm_k_img(self.k_img(img))
480
+ v_img = self.v_img(img)
481
+ img_out = self.attn(q, k_img, v_img)
482
+ out = out + img_out
483
+ return self.o(out)
484
+
485
+ cross_attn.forward = forward_with_lam.__get__(cross_attn, cross_attn.__class__)
486
+ cross_attn._comp_attn_patched = True
487
+
488
+
489
+ def get_grid_from_latents(latents: torch.Tensor, patch_size: tuple[int, int, int]) -> tuple[int, int, int]:
490
+ f = latents.shape[2] // patch_size[0]
491
+ h = latents.shape[3] // patch_size[1]
492
+ w = latents.shape[4] // patch_size[2]
493
+ return f, h, w
494
+
495
+
496
+ def wrap_model_fn(pipe) -> None:
497
+ if getattr(pipe, "_comp_attn_model_fn_patched", False):
498
+ return
499
+ orig_model_fn = pipe.model_fn
500
+
501
+ def model_fn_wrapper(*args, **kwargs):
502
+ comp_attn_state = kwargs.pop("comp_attn_state", None)
503
+ height = kwargs.get("height")
504
+ width = kwargs.get("width")
505
+ num_frames = kwargs.get("num_frames")
506
+ if num_frames is not None:
507
+ pipe._comp_attn_num_frames = num_frames
508
+ if comp_attn_state is None:
509
+ return orig_model_fn(*args, **kwargs)
510
+ latents = kwargs.get("latents")
511
+ timestep = kwargs.get("timestep")
512
+ context = kwargs.get("context")
513
+ clip_feature = kwargs.get("clip_feature")
514
+ reference_latents = kwargs.get("reference_latents")
515
+ state_vectors = comp_attn_state.get("state_vectors")
516
+ state_weights = comp_attn_state.get("state_weights")
517
+ state_len = int(comp_attn_state.get("state_len", 0))
518
+ prompt_len = int(comp_attn_state.get("prompt_len", context.shape[1] if context is not None else 0))
519
+ if context is not None and timestep is not None:
520
+ context = apply_sci(context, comp_attn_state, timestep)
521
+ if state_vectors is not None and state_len > 0:
522
+ state_vec = state_vectors.to(device=context.device, dtype=context.dtype)
523
+ if state_vec.dim() == 2:
524
+ state_vec = state_vec.unsqueeze(0)
525
+ if state_vec.shape[0] != context.shape[0]:
526
+ state_vec = state_vec.repeat(context.shape[0], 1, 1)
527
+ context = torch.cat([context, state_vec], dim=1)
528
+ kwargs["context"] = context
529
+ subject_mask = comp_attn_state.get("subject_token_mask")
530
+ if subject_mask is not None:
531
+ clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0
532
+ pad_clip = torch.zeros((subject_mask.shape[0], clip_len), dtype=torch.bool)
533
+ pad_state = torch.zeros((subject_mask.shape[0], state_len), dtype=torch.bool)
534
+ comp_attn_state["subject_token_mask_lam"] = torch.cat([pad_clip, subject_mask.cpu(), pad_state], dim=1)
535
+ if state_vectors is not None and state_len > 0:
536
+ clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0
537
+ pad_prompt = torch.zeros((state_len, clip_len + prompt_len), dtype=torch.bool)
538
+ ones_state = torch.ones((state_len, state_len), dtype=torch.bool)
539
+ state_token_mask = torch.cat([pad_prompt, ones_state], dim=1).any(dim=0)
540
+ comp_attn_state["state_token_mask"] = state_token_mask
541
+ if latents is not None and height is not None and width is not None:
542
+ f, h, w = get_grid_from_latents(latents, pipe.dit.patch_size)
543
+ if comp_attn_state.get("enable_lam", False):
544
+ q_len = f * h * w
545
+ if reference_latents is not None:
546
+ q_len = (f + 1) * h * w
547
+ layout_mask = comp_attn_state.get("layout_mask")
548
+ layout_shape = comp_attn_state.get("layout_shape")
549
+ if layout_mask is None or layout_shape != (latents.shape[0], q_len):
550
+ layout_mask = build_layout_mask_from_bboxes(
551
+ comp_attn_state.get("layout_bboxes"),
552
+ (f, h, w),
553
+ (int(height), int(width)),
554
+ device=latents.device,
555
+ dtype=latents.dtype,
556
+ )
557
+ if reference_latents is not None:
558
+ pad = torch.zeros((layout_mask.shape[0], layout_mask.shape[1], h * w), device=latents.device, dtype=latents.dtype)
559
+ layout_mask = torch.cat([pad, layout_mask], dim=-1)
560
+ if layout_mask.shape[0] != latents.shape[0]:
561
+ layout_mask = layout_mask.repeat(latents.shape[0], 1, 1)
562
+ comp_attn_state["layout_mask"] = layout_mask
563
+ comp_attn_state["layout_shape"] = (latents.shape[0], q_len)
564
+ if state_weights is not None:
565
+ weights = state_weights.to(device=latents.device, dtype=latents.dtype)
566
+ if weights.shape[0] != latents.shape[0]:
567
+ weights = weights.repeat(latents.shape[0], 1, 1)
568
+ if weights.shape[1] != f:
569
+ weights = weights.transpose(1, 2)
570
+ weights = F.interpolate(weights, size=f, mode="linear", align_corners=True)
571
+ weights = weights.transpose(1, 2)
572
+ if reference_latents is not None:
573
+ pad = torch.zeros((weights.shape[0], 1, weights.shape[2]), device=weights.device, dtype=weights.dtype)
574
+ weights = torch.cat([pad, weights], dim=1)
575
+ f = f + 1
576
+ comp_attn_state["state_token_weights"] = weights
577
+ comp_attn_state["grid_shape"] = (f, h, w)
578
+ if (
579
+ latents is not None
580
+ and latents.shape[0] == 2
581
+ and not comp_attn_state.get("apply_to_negative", False)
582
+ and "apply_mask" not in comp_attn_state
583
+ ):
584
+ comp_attn_state["apply_mask"] = torch.tensor([1.0, 0.0], device=latents.device, dtype=latents.dtype)
585
+ pipe._comp_attn_runtime_state = comp_attn_state
586
+ try:
587
+ return orig_model_fn(*args, **kwargs)
588
+ finally:
589
+ pipe._comp_attn_runtime_state = None
590
+
591
+ pipe.model_fn = model_fn_wrapper
592
+ pipe._comp_attn_model_fn_patched = True
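
A sketch of how these hooks might be attached to a pipeline (pipe stands for a Wan-style video pipeline exposing dit.blocks[*].cross_attn and model_fn; how _comp_attn_config is normally populated may be handled elsewhere in the repository, and the subjects and boxes below are placeholders):

    from diffsynth.models.comp_attn_model import CompAttnConfig, patch_cross_attention, wrap_model_fn  # assumed path

    config = CompAttnConfig(
        subjects=["a red fox", "a white rabbit"],
        bboxes=[[[0, 0, 320, 320]], [[320, 0, 640, 320]]],  # one pixel-space box per subject per layout frame
        enable_sci=True,
        enable_lam=True,
    )
    pipe._comp_attn_config = config       # read by CompAttnUnit.process
    patch_cross_attention(pipe)           # routes cross-attention through lam_attention when runtime state is set
    wrap_model_fn(pipe)                   # injects SCI / LAM state around each model_fn call
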
diffsynth/models/dinov3_image_encoder.py ADDED
@@ -0,0 +1,94 @@
1
+ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
2
+ from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
3
+ import torch
4
+
5
+
6
+ class DINOv3ImageEncoder(DINOv3ViTModel):
7
+ def __init__(self):
8
+ config = DINOv3ViTConfig(
9
+ architectures = [
10
+ "DINOv3ViTModel"
11
+ ],
12
+ attention_dropout = 0.0,
13
+ drop_path_rate = 0.0,
14
+ dtype = "float32",
15
+ hidden_act = "silu",
16
+ hidden_size = 4096,
17
+ image_size = 224,
18
+ initializer_range = 0.02,
19
+ intermediate_size = 8192,
20
+ key_bias = False,
21
+ layer_norm_eps = 1e-05,
22
+ layerscale_value = 1.0,
23
+ mlp_bias = True,
24
+ model_type = "dinov3_vit",
25
+ num_attention_heads = 32,
26
+ num_channels = 3,
27
+ num_hidden_layers = 40,
28
+ num_register_tokens = 4,
29
+ patch_size = 16,
30
+ pos_embed_jitter = None,
31
+ pos_embed_rescale = 2.0,
32
+ pos_embed_shift = None,
33
+ proj_bias = True,
34
+ query_bias = False,
35
+ rope_theta = 100.0,
36
+ transformers_version = "4.56.1",
37
+ use_gated_mlp = True,
38
+ value_bias = False
39
+ )
40
+ super().__init__(config)
41
+ self.processor = DINOv3ViTImageProcessorFast(
42
+ crop_size = None,
43
+ data_format = "channels_first",
44
+ default_to_square = True,
45
+ device = None,
46
+ disable_grouping = None,
47
+ do_center_crop = None,
48
+ do_convert_rgb = None,
49
+ do_normalize = True,
50
+ do_rescale = True,
51
+ do_resize = True,
52
+ image_mean = [
53
+ 0.485,
54
+ 0.456,
55
+ 0.406
56
+ ],
57
+ image_processor_type = "DINOv3ViTImageProcessorFast",
58
+ image_std = [
59
+ 0.229,
60
+ 0.224,
61
+ 0.225
62
+ ],
63
+ input_data_format = None,
64
+ resample = 2,
65
+ rescale_factor = 0.00392156862745098,
66
+ return_tensors = None,
67
+ size = {
68
+ "height": 224,
69
+ "width": 224
70
+ }
71
+ )
72
+
73
+ def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
74
+ inputs = self.processor(images=image, return_tensors="pt")
75
+ pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
76
+ bool_masked_pos = None
77
+ head_mask = None
78
+
79
+ pixel_values = pixel_values.to(torch_dtype)
80
+ hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
81
+ position_embeddings = self.rope_embeddings(pixel_values)
82
+
83
+ for i, layer_module in enumerate(self.layer):
84
+ layer_head_mask = head_mask[i] if head_mask is not None else None
85
+ hidden_states = layer_module(
86
+ hidden_states,
87
+ attention_mask=layer_head_mask,
88
+ position_embeddings=position_embeddings,
89
+ )
90
+
91
+ sequence_output = self.norm(hidden_states)
92
+ pooled_output = sequence_output[:, 0, :]
93
+
94
+ return pooled_output
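
A minimal usage sketch, assuming a transformers version that ships DINOv3 and that pretrained weights are loaded separately (a randomly initialized encoder is used here only to show the call signature; the 40-layer model is large):

    import torch
    from PIL import Image
    from diffsynth.models.dinov3_image_encoder import DINOv3ImageEncoder  # assumed import path

    encoder = DINOv3ImageEncoder().to(dtype=torch.bfloat16, device="cuda").eval()
    image = Image.new("RGB", (512, 512), (128, 128, 128))
    with torch.no_grad():
        pooled = encoder(image, torch_dtype=torch.bfloat16, device="cuda")
    print(pooled.shape)  # (1, 4096): the CLS-token feature after the final norm
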
diffsynth/models/flux2_dit.py ADDED
@@ -0,0 +1,1057 @@
1
+ import inspect
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch, math
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from ..core.attention import attention_forward
9
+ from ..core.gradient import gradient_checkpoint_forward
10
+
11
+
12
+ def get_timestep_embedding(
13
+ timesteps: torch.Tensor,
14
+ embedding_dim: int,
15
+ flip_sin_to_cos: bool = False,
16
+ downscale_freq_shift: float = 1,
17
+ scale: float = 1,
18
+ max_period: int = 10000,
19
+ ) -> torch.Tensor:
20
+ """
21
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
22
+
23
+ Args:
24
+ timesteps (torch.Tensor):
25
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
26
+ embedding_dim (int):
27
+ the dimension of the output.
28
+ flip_sin_to_cos (bool):
29
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
30
+ downscale_freq_shift (float):
31
+ Controls the delta between frequencies between dimensions
32
+ scale (float):
33
+ Scaling factor applied to the embeddings.
34
+ max_period (int):
35
+ Controls the maximum frequency of the embeddings
36
+ Returns:
37
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
38
+ """
39
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
40
+
41
+ half_dim = embedding_dim // 2
42
+ exponent = -math.log(max_period) * torch.arange(
43
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
44
+ )
45
+ exponent = exponent / (half_dim - downscale_freq_shift)
46
+
47
+ emb = torch.exp(exponent)
48
+ emb = timesteps[:, None].float() * emb[None, :]
49
+
50
+ # scale embeddings
51
+ emb = scale * emb
52
+
53
+ # concat sine and cosine embeddings
54
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
55
+
56
+ # flip sine and cosine embeddings
57
+ if flip_sin_to_cos:
58
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
59
+
60
+ # zero pad
61
+ if embedding_dim % 2 == 1:
62
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
63
+ return emb
64
+
65
+
66
+ class TimestepEmbedding(nn.Module):
67
+ def __init__(
68
+ self,
69
+ in_channels: int,
70
+ time_embed_dim: int,
71
+ act_fn: str = "silu",
72
+ out_dim: int = None,
73
+ post_act_fn: Optional[str] = None,
74
+ cond_proj_dim=None,
75
+ sample_proj_bias=True,
76
+ ):
77
+ super().__init__()
78
+
79
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
80
+
81
+ if cond_proj_dim is not None:
82
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
83
+ else:
84
+ self.cond_proj = None
85
+
86
+ self.act = torch.nn.SiLU()
87
+
88
+ if out_dim is not None:
89
+ time_embed_dim_out = out_dim
90
+ else:
91
+ time_embed_dim_out = time_embed_dim
92
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
93
+
94
+ if post_act_fn is None:
95
+ self.post_act = None
96
+
97
+ def forward(self, sample, condition=None):
98
+ if condition is not None:
99
+ sample = sample + self.cond_proj(condition)
100
+ sample = self.linear_1(sample)
101
+
102
+ if self.act is not None:
103
+ sample = self.act(sample)
104
+
105
+ sample = self.linear_2(sample)
106
+
107
+ if self.post_act is not None:
108
+ sample = self.post_act(sample)
109
+ return sample
110
+
111
+
112
+ class Timesteps(nn.Module):
113
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
114
+ super().__init__()
115
+ self.num_channels = num_channels
116
+ self.flip_sin_to_cos = flip_sin_to_cos
117
+ self.downscale_freq_shift = downscale_freq_shift
118
+ self.scale = scale
119
+
120
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
121
+ t_emb = get_timestep_embedding(
122
+ timesteps,
123
+ self.num_channels,
124
+ flip_sin_to_cos=self.flip_sin_to_cos,
125
+ downscale_freq_shift=self.downscale_freq_shift,
126
+ scale=self.scale,
127
+ )
128
+ return t_emb
129
+
130
+
131
+ class AdaLayerNormContinuous(nn.Module):
132
+ r"""
133
+ Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
134
+
135
+ Args:
136
+ embedding_dim (`int`): Embedding dimension to use during projection.
137
+ conditioning_embedding_dim (`int`): Dimension of the input condition.
138
+ elementwise_affine (`bool`, defaults to `True`):
139
+ Boolean flag to denote if affine transformation should be applied.
140
+ eps (`float`, defaults to 1e-5): Epsilon factor.
141
+ bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
142
+ norm_type (`str`, defaults to `"layer_norm"`):
143
+ Normalization layer to use. Values supported: "layer_norm", "rms_norm".
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ embedding_dim: int,
149
+ conditioning_embedding_dim: int,
150
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
151
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
152
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
153
+ # However, this is how it was implemented in the original code, and it's rather likely you should
154
+ # set `elementwise_affine` to False.
155
+ elementwise_affine=True,
156
+ eps=1e-5,
157
+ bias=True,
158
+ norm_type="layer_norm",
159
+ ):
160
+ super().__init__()
161
+ self.silu = nn.SiLU()
162
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
163
+ if norm_type == "layer_norm":
164
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
165
+
166
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
167
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
168
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
169
+ scale, shift = torch.chunk(emb, 2, dim=1)
170
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
171
+ return x
172
+
173
+
174
+ def get_1d_rotary_pos_embed(
175
+ dim: int,
176
+ pos: Union[np.ndarray, int],
177
+ theta: float = 10000.0,
178
+ use_real=False,
179
+ linear_factor=1.0,
180
+ ntk_factor=1.0,
181
+ repeat_interleave_real=True,
182
+ freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
183
+ ):
184
+ """
185
+ Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
186
+
187
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
188
+ index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
189
+ data type.
190
+
191
+ Args:
192
+ dim (`int`): Dimension of the frequency tensor.
193
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
194
+ theta (`float`, *optional*, defaults to 10000.0):
195
+ Scaling factor for frequency computation. Defaults to 10000.0.
196
+ use_real (`bool`, *optional*):
197
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
198
+ linear_factor (`float`, *optional*, defaults to 1.0):
199
+ Scaling factor for the context extrapolation. Defaults to 1.0.
200
+ ntk_factor (`float`, *optional*, defaults to 1.0):
201
+ Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
202
+ repeat_interleave_real (`bool`, *optional*, defaults to `True`):
203
+ If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
204
+ Otherwise, they are concatenated with themselves.
205
+ freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
206
+ the dtype of the frequency tensor.
207
+ Returns:
208
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
209
+ """
210
+ assert dim % 2 == 0
211
+
212
+ if isinstance(pos, int):
213
+ pos = torch.arange(pos)
214
+ if isinstance(pos, np.ndarray):
215
+ pos = torch.from_numpy(pos) # type: ignore # [S]
216
+
217
+ theta = theta * ntk_factor
218
+ freqs = (
219
+ 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
220
+ ) # [D/2]
221
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
222
+ is_npu = freqs.device.type == "npu"
223
+ if is_npu:
224
+ freqs = freqs.float()
225
+ if use_real and repeat_interleave_real:
226
+ # flux, hunyuan-dit, cogvideox
227
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
228
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
229
+ return freqs_cos, freqs_sin
230
+ elif use_real:
231
+ # stable audio, allegro
232
+ freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
233
+ freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
234
+ return freqs_cos, freqs_sin
235
+ else:
236
+ # lumina
237
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
238
+ return freqs_cis
239
+
240
+
241
+ def apply_rotary_emb(
242
+ x: torch.Tensor,
243
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
244
+ use_real: bool = True,
245
+ use_real_unbind_dim: int = -1,
246
+ sequence_dim: int = 2,
247
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
248
+ """
249
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
250
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
251
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
252
+ tensors contain rotary embeddings and are returned as real tensors.
253
+
254
+ Args:
255
+ x (`torch.Tensor`):
256
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
257
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
258
+
259
+ Returns:
260
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
261
+ """
262
+ if use_real:
263
+ cos, sin = freqs_cis # [S, D]
264
+ if sequence_dim == 2:
265
+ cos = cos[None, None, :, :]
266
+ sin = sin[None, None, :, :]
267
+ elif sequence_dim == 1:
268
+ cos = cos[None, :, None, :]
269
+ sin = sin[None, :, None, :]
270
+ else:
271
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
272
+
273
+ cos, sin = cos.to(x.device), sin.to(x.device)
274
+
275
+ if use_real_unbind_dim == -1:
276
+ # Used for flux, cogvideox, hunyuan-dit
277
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
278
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
279
+ elif use_real_unbind_dim == -2:
280
+ # Used for Stable Audio, OmniGen, CogView4 and Cosmos
281
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
282
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
283
+ else:
284
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
285
+
286
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
287
+
288
+ return out
289
+ else:
290
+ # used for lumina
291
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
292
+ freqs_cis = freqs_cis.unsqueeze(2)
293
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
294
+
295
+ return x_out.type_as(x)
296
+
297
+ def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
298
+ query = attn.to_q(hidden_states)
299
+ key = attn.to_k(hidden_states)
300
+ value = attn.to_v(hidden_states)
301
+
302
+ encoder_query = encoder_key = encoder_value = None
303
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
304
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
305
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
306
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
307
+
308
+ return query, key, value, encoder_query, encoder_key, encoder_value
309
+
310
+
311
+ def _get_fused_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
312
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
313
+
314
+ encoder_query = encoder_key = encoder_value = None
315
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
316
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
317
+
318
+ return query, key, value, encoder_query, encoder_key, encoder_value
319
+
320
+
321
+ def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
322
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
323
+
324
+
325
+ class Flux2SwiGLU(nn.Module):
326
+ """
327
+ Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
328
+ layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
329
+ """
330
+
331
+ def __init__(self):
332
+ super().__init__()
333
+ self.gate_fn = nn.SiLU()
334
+
335
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
336
+ x1, x2 = x.chunk(2, dim=-1)
337
+ x = self.gate_fn(x1) * x2
338
+ return x
339
+
340
+
341
+ class Flux2FeedForward(nn.Module):
342
+ def __init__(
343
+ self,
344
+ dim: int,
345
+ dim_out: Optional[int] = None,
346
+ mult: float = 3.0,
347
+ inner_dim: Optional[int] = None,
348
+ bias: bool = False,
349
+ ):
350
+ super().__init__()
351
+ if inner_dim is None:
352
+ inner_dim = int(dim * mult)
353
+ dim_out = dim_out or dim
354
+
355
+ # Flux2SwiGLU will reduce the dimension by half
356
+ self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
357
+ self.act_fn = Flux2SwiGLU()
358
+ self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
359
+
360
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
361
+ x = self.linear_in(x)
362
+ x = self.act_fn(x)
363
+ x = self.linear_out(x)
364
+ return x
365
+
366
+
367
+ class Flux2AttnProcessor:
368
+ _attention_backend = None
369
+ _parallel_config = None
370
+
371
+ def __init__(self):
372
+ if not hasattr(F, "scaled_dot_product_attention"):
373
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
374
+
375
+ def __call__(
376
+ self,
377
+ attn: "Flux2Attention",
378
+ hidden_states: torch.Tensor,
379
+ encoder_hidden_states: torch.Tensor = None,
380
+ attention_mask: Optional[torch.Tensor] = None,
381
+ image_rotary_emb: Optional[torch.Tensor] = None,
382
+ ) -> torch.Tensor:
383
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
384
+ attn, hidden_states, encoder_hidden_states
385
+ )
386
+
387
+ query = query.unflatten(-1, (attn.heads, -1))
388
+ key = key.unflatten(-1, (attn.heads, -1))
389
+ value = value.unflatten(-1, (attn.heads, -1))
390
+
391
+ query = attn.norm_q(query)
392
+ key = attn.norm_k(key)
393
+
394
+ if attn.added_kv_proj_dim is not None:
395
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
396
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
397
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
398
+
399
+ encoder_query = attn.norm_added_q(encoder_query)
400
+ encoder_key = attn.norm_added_k(encoder_key)
401
+
402
+ query = torch.cat([encoder_query, query], dim=1)
403
+ key = torch.cat([encoder_key, key], dim=1)
404
+ value = torch.cat([encoder_value, value], dim=1)
405
+
406
+ if image_rotary_emb is not None:
407
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
408
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
409
+
410
+ hidden_states = attention_forward(
411
+ query,
412
+ key,
413
+ value,
414
+ q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
415
+ )
416
+ hidden_states = hidden_states.flatten(2, 3)
417
+ hidden_states = hidden_states.to(query.dtype)
418
+
419
+ if encoder_hidden_states is not None:
420
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
421
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
422
+ )
423
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
424
+
425
+ hidden_states = attn.to_out[0](hidden_states)
426
+ hidden_states = attn.to_out[1](hidden_states)
427
+
428
+ if encoder_hidden_states is not None:
429
+ return hidden_states, encoder_hidden_states
430
+ else:
431
+ return hidden_states
432
+
433
+
434
+ class Flux2Attention(torch.nn.Module):
435
+ _default_processor_cls = Flux2AttnProcessor
436
+ _available_processors = [Flux2AttnProcessor]
437
+
438
+ def __init__(
439
+ self,
440
+ query_dim: int,
441
+ heads: int = 8,
442
+ dim_head: int = 64,
443
+ dropout: float = 0.0,
444
+ bias: bool = False,
445
+ added_kv_proj_dim: Optional[int] = None,
446
+ added_proj_bias: Optional[bool] = True,
447
+ out_bias: bool = True,
448
+ eps: float = 1e-5,
449
+ out_dim: int = None,
450
+ elementwise_affine: bool = True,
451
+ processor=None,
452
+ ):
453
+ super().__init__()
454
+
455
+ self.head_dim = dim_head
456
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
457
+ self.query_dim = query_dim
458
+ self.out_dim = out_dim if out_dim is not None else query_dim
459
+ self.heads = out_dim // dim_head if out_dim is not None else heads
460
+
461
+ self.use_bias = bias
462
+ self.dropout = dropout
463
+
464
+ self.added_kv_proj_dim = added_kv_proj_dim
465
+ self.added_proj_bias = added_proj_bias
466
+
467
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
468
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
469
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
470
+
471
+ # QK Norm
472
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
473
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
474
+
475
+ self.to_out = torch.nn.ModuleList([])
476
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
477
+ self.to_out.append(torch.nn.Dropout(dropout))
478
+
479
+ if added_kv_proj_dim is not None:
480
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
481
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
482
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
483
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
484
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
485
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
486
+
487
+ if processor is None:
488
+ processor = self._default_processor_cls()
489
+ self.processor = processor
490
+
491
+ def forward(
492
+ self,
493
+ hidden_states: torch.Tensor,
494
+ encoder_hidden_states: Optional[torch.Tensor] = None,
495
+ attention_mask: Optional[torch.Tensor] = None,
496
+ image_rotary_emb: Optional[torch.Tensor] = None,
497
+ **kwargs,
498
+ ) -> torch.Tensor:
499
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
500
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
501
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
502
+
503
+
504
+ class Flux2ParallelSelfAttnProcessor:
505
+ _attention_backend = None
506
+ _parallel_config = None
507
+
508
+ def __init__(self):
509
+ if not hasattr(F, "scaled_dot_product_attention"):
510
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
511
+
512
+ def __call__(
513
+ self,
514
+ attn: "Flux2ParallelSelfAttention",
515
+ hidden_states: torch.Tensor,
516
+ attention_mask: Optional[torch.Tensor] = None,
517
+ image_rotary_emb: Optional[torch.Tensor] = None,
518
+ ) -> torch.Tensor:
519
+ # Parallel in (QKV + MLP in) projection
520
+ hidden_states = attn.to_qkv_mlp_proj(hidden_states)
521
+ qkv, mlp_hidden_states = torch.split(
522
+ hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
523
+ )
524
+
525
+ # Handle the attention logic
526
+ query, key, value = qkv.chunk(3, dim=-1)
527
+
528
+ query = query.unflatten(-1, (attn.heads, -1))
529
+ key = key.unflatten(-1, (attn.heads, -1))
530
+ value = value.unflatten(-1, (attn.heads, -1))
531
+
532
+ query = attn.norm_q(query)
533
+ key = attn.norm_k(key)
534
+
535
+ if image_rotary_emb is not None:
536
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
537
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
538
+
539
+ hidden_states = attention_forward(
540
+ query,
541
+ key,
542
+ value,
543
+ q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
544
+ )
545
+ hidden_states = hidden_states.flatten(2, 3)
546
+ hidden_states = hidden_states.to(query.dtype)
547
+
548
+ # Handle the feedforward (FF) logic
549
+ mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
550
+
551
+ # Concatenate and parallel output projection
552
+ hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
553
+ hidden_states = attn.to_out(hidden_states)
554
+
555
+ return hidden_states
556
+
557
+
558
+ class Flux2ParallelSelfAttention(torch.nn.Module):
559
+ """
560
+ Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
561
+
562
+ This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
563
+ input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
564
+ paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
565
+ """
566
+
567
+ _default_processor_cls = Flux2ParallelSelfAttnProcessor
568
+ _available_processors = [Flux2ParallelSelfAttnProcessor]
569
+ # Does not support QKV fusion as the QKV projections are always fused
570
+ _supports_qkv_fusion = False
571
+
572
+ def __init__(
573
+ self,
574
+ query_dim: int,
575
+ heads: int = 8,
576
+ dim_head: int = 64,
577
+ dropout: float = 0.0,
578
+ bias: bool = False,
579
+ out_bias: bool = True,
580
+ eps: float = 1e-5,
581
+ out_dim: int = None,
582
+ elementwise_affine: bool = True,
583
+ mlp_ratio: float = 4.0,
584
+ mlp_mult_factor: int = 2,
585
+ processor=None,
586
+ ):
587
+ super().__init__()
588
+
589
+ self.head_dim = dim_head
590
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
591
+ self.query_dim = query_dim
592
+ self.out_dim = out_dim if out_dim is not None else query_dim
593
+ self.heads = out_dim // dim_head if out_dim is not None else heads
594
+
595
+ self.use_bias = bias
596
+ self.dropout = dropout
597
+
598
+ self.mlp_ratio = mlp_ratio
599
+ self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
600
+ self.mlp_mult_factor = mlp_mult_factor
601
+
602
+ # Fused QKV projections + MLP input projection
603
+ self.to_qkv_mlp_proj = torch.nn.Linear(
604
+ self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
605
+ )
606
+ self.mlp_act_fn = Flux2SwiGLU()
607
+
608
+ # QK Norm
609
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
610
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
611
+
612
+ # Fused attention output projection + MLP output projection
613
+ self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
614
+
615
+ if processor is None:
616
+ processor = self._default_processor_cls()
617
+ self.processor = processor
618
+
619
+ def forward(
620
+ self,
621
+ hidden_states: torch.Tensor,
622
+ attention_mask: Optional[torch.Tensor] = None,
623
+ image_rotary_emb: Optional[torch.Tensor] = None,
624
+ **kwargs,
625
+ ) -> torch.Tensor:
626
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
627
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
628
+ return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
629
+
630
+
631
+ class Flux2SingleTransformerBlock(nn.Module):
632
+ def __init__(
633
+ self,
634
+ dim: int,
635
+ num_attention_heads: int,
636
+ attention_head_dim: int,
637
+ mlp_ratio: float = 3.0,
638
+ eps: float = 1e-6,
639
+ bias: bool = False,
640
+ ):
641
+ super().__init__()
642
+
643
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
644
+
645
+ # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
646
+ # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
647
+ # for a visual depiction of this type of transformer block.
648
+ self.attn = Flux2ParallelSelfAttention(
649
+ query_dim=dim,
650
+ dim_head=attention_head_dim,
651
+ heads=num_attention_heads,
652
+ out_dim=dim,
653
+ bias=bias,
654
+ out_bias=bias,
655
+ eps=eps,
656
+ mlp_ratio=mlp_ratio,
657
+ mlp_mult_factor=2,
658
+ processor=Flux2ParallelSelfAttnProcessor(),
659
+ )
660
+
661
+ def forward(
662
+ self,
663
+ hidden_states: torch.Tensor,
664
+ encoder_hidden_states: Optional[torch.Tensor],
665
+ temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
666
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
667
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
668
+ split_hidden_states: bool = False,
669
+ text_seq_len: Optional[int] = None,
670
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
671
+ # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
672
+ # concatenated
673
+ if encoder_hidden_states is not None:
674
+ text_seq_len = encoder_hidden_states.shape[1]
675
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
676
+
677
+ mod_shift, mod_scale, mod_gate = temb_mod_params
678
+
679
+ norm_hidden_states = self.norm(hidden_states)
680
+ norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
681
+
682
+ joint_attention_kwargs = joint_attention_kwargs or {}
683
+ attn_output = self.attn(
684
+ hidden_states=norm_hidden_states,
685
+ image_rotary_emb=image_rotary_emb,
686
+ **joint_attention_kwargs,
687
+ )
688
+
689
+ hidden_states = hidden_states + mod_gate * attn_output
690
+ if hidden_states.dtype == torch.float16:
691
+ hidden_states = hidden_states.clip(-65504, 65504)
692
+
693
+ if split_hidden_states:
694
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
695
+ return encoder_hidden_states, hidden_states
696
+ else:
697
+ return hidden_states
698
+
699
+
700
+ class Flux2TransformerBlock(nn.Module):
701
+ def __init__(
702
+ self,
703
+ dim: int,
704
+ num_attention_heads: int,
705
+ attention_head_dim: int,
706
+ mlp_ratio: float = 3.0,
707
+ eps: float = 1e-6,
708
+ bias: bool = False,
709
+ ):
710
+ super().__init__()
711
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
712
+
713
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
714
+ self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
715
+
716
+ self.attn = Flux2Attention(
717
+ query_dim=dim,
718
+ added_kv_proj_dim=dim,
719
+ dim_head=attention_head_dim,
720
+ heads=num_attention_heads,
721
+ out_dim=dim,
722
+ bias=bias,
723
+ added_proj_bias=bias,
724
+ out_bias=bias,
725
+ eps=eps,
726
+ processor=Flux2AttnProcessor(),
727
+ )
728
+
729
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
730
+ self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
731
+
732
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
733
+ self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
734
+
735
+ def forward(
736
+ self,
737
+ hidden_states: torch.Tensor,
738
+ encoder_hidden_states: torch.Tensor,
739
+ temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
740
+ temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
741
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
742
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
743
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
744
+ joint_attention_kwargs = joint_attention_kwargs or {}
745
+
746
+ # Modulation parameters shape: [1, 1, self.dim]
747
+ (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
748
+ (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
749
+
750
+ # Img stream
751
+ norm_hidden_states = self.norm1(hidden_states)
752
+ norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
753
+
754
+ # Conditioning txt stream
755
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
756
+ norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
757
+
758
+ # Attention on concatenated img + txt stream
759
+ attention_outputs = self.attn(
760
+ hidden_states=norm_hidden_states,
761
+ encoder_hidden_states=norm_encoder_hidden_states,
762
+ image_rotary_emb=image_rotary_emb,
763
+ **joint_attention_kwargs,
764
+ )
765
+
766
+ attn_output, context_attn_output = attention_outputs
767
+
768
+ # Process attention outputs for the image stream (`hidden_states`).
769
+ attn_output = gate_msa * attn_output
770
+ hidden_states = hidden_states + attn_output
771
+
772
+ norm_hidden_states = self.norm2(hidden_states)
773
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
774
+
775
+ ff_output = self.ff(norm_hidden_states)
776
+ hidden_states = hidden_states + gate_mlp * ff_output
777
+
778
+ # Process attention outputs for the text stream (`encoder_hidden_states`).
779
+ context_attn_output = c_gate_msa * context_attn_output
780
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
781
+
782
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
783
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
784
+
785
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
786
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
787
+ if encoder_hidden_states.dtype == torch.float16:
788
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
789
+
790
+ return encoder_hidden_states, hidden_states
791
+
792
+
793
+ class Flux2PosEmbed(nn.Module):
794
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
795
+ def __init__(self, theta: int, axes_dim: List[int]):
796
+ super().__init__()
797
+ self.theta = theta
798
+ self.axes_dim = axes_dim
799
+
800
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
801
+ # Expected ids shape: [S, len(self.axes_dim)]
802
+ cos_out = []
803
+ sin_out = []
804
+ pos = ids.float()
805
+ is_mps = ids.device.type == "mps"
806
+ is_npu = ids.device.type == "npu"
807
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
808
+ # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
809
+ for i in range(len(self.axes_dim)):
810
+ cos, sin = get_1d_rotary_pos_embed(
811
+ self.axes_dim[i],
812
+ pos[..., i],
813
+ theta=self.theta,
814
+ repeat_interleave_real=True,
815
+ use_real=True,
816
+ freqs_dtype=freqs_dtype,
817
+ )
818
+ cos_out.append(cos)
819
+ sin_out.append(sin)
820
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
821
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
822
+ return freqs_cos, freqs_sin
823
+
824
+
825
+ class Flux2TimestepGuidanceEmbeddings(nn.Module):
826
+ def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
827
+ super().__init__()
828
+
829
+ self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
830
+ self.timestep_embedder = TimestepEmbedding(
831
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
832
+ )
833
+
834
+ self.guidance_embedder = TimestepEmbedding(
835
+ in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
836
+ )
837
+
838
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
839
+ timesteps_proj = self.time_proj(timestep)
840
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
841
+
842
+ guidance_proj = self.time_proj(guidance)
843
+ guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
844
+
845
+ time_guidance_emb = timesteps_emb + guidance_emb
846
+
847
+ return time_guidance_emb
848
+
849
+
850
+ class Flux2Modulation(nn.Module):
851
+ def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
852
+ super().__init__()
853
+ self.mod_param_sets = mod_param_sets
854
+
855
+ self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
856
+ self.act_fn = nn.SiLU()
857
+
858
+ def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
859
+ mod = self.act_fn(temb)
860
+ mod = self.linear(mod)
861
+
862
+ if mod.ndim == 2:
863
+ mod = mod.unsqueeze(1)
864
+ mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
865
+ # Return tuple of 3-tuples of modulation params shift/scale/gate
866
+ return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
867
+
868
+
869
+ class Flux2DiT(torch.nn.Module):
870
+ def __init__(
871
+ self,
872
+ patch_size: int = 1,
873
+ in_channels: int = 128,
874
+ out_channels: Optional[int] = None,
875
+ num_layers: int = 8,
876
+ num_single_layers: int = 48,
877
+ attention_head_dim: int = 128,
878
+ num_attention_heads: int = 48,
879
+ joint_attention_dim: int = 15360,
880
+ timestep_guidance_channels: int = 256,
881
+ mlp_ratio: float = 3.0,
882
+ axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
883
+ rope_theta: int = 2000,
884
+ eps: float = 1e-6,
885
+ ):
886
+ super().__init__()
887
+ self.out_channels = out_channels or in_channels
888
+ self.inner_dim = num_attention_heads * attention_head_dim
889
+
890
+ # 1. Sinusoidal positional embedding for RoPE on image and text tokens
891
+ self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
892
+
893
+ # 2. Combined timestep + guidance embedding
894
+ self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
895
+ in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
896
+ )
897
+
898
+ # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
899
+ # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
900
+ self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
901
+ self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
902
+ # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
903
+ self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
904
+
905
+ # 4. Input projections
906
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
907
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
908
+
909
+ # 5. Double Stream Transformer Blocks
910
+ self.transformer_blocks = nn.ModuleList(
911
+ [
912
+ Flux2TransformerBlock(
913
+ dim=self.inner_dim,
914
+ num_attention_heads=num_attention_heads,
915
+ attention_head_dim=attention_head_dim,
916
+ mlp_ratio=mlp_ratio,
917
+ eps=eps,
918
+ bias=False,
919
+ )
920
+ for _ in range(num_layers)
921
+ ]
922
+ )
923
+
924
+ # 6. Single Stream Transformer Blocks
925
+ self.single_transformer_blocks = nn.ModuleList(
926
+ [
927
+ Flux2SingleTransformerBlock(
928
+ dim=self.inner_dim,
929
+ num_attention_heads=num_attention_heads,
930
+ attention_head_dim=attention_head_dim,
931
+ mlp_ratio=mlp_ratio,
932
+ eps=eps,
933
+ bias=False,
934
+ )
935
+ for _ in range(num_single_layers)
936
+ ]
937
+ )
938
+
939
+ # 7. Output layers
940
+ self.norm_out = AdaLayerNormContinuous(
941
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
942
+ )
943
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
944
+
945
+ self.gradient_checkpointing = False
946
+
947
+ def forward(
948
+ self,
949
+ hidden_states: torch.Tensor,
950
+ encoder_hidden_states: torch.Tensor = None,
951
+ timestep: torch.LongTensor = None,
952
+ img_ids: torch.Tensor = None,
953
+ txt_ids: torch.Tensor = None,
954
+ guidance: torch.Tensor = None,
955
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
956
+ return_dict: bool = True,
957
+ use_gradient_checkpointing=False,
958
+ use_gradient_checkpointing_offload=False,
959
+ ) -> torch.Tensor:
960
+ """
961
+ The [`Flux2DiT`] forward method.
962
+
963
+ Args:
964
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
965
+ Input `hidden_states`.
966
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
967
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
968
+ timestep ( `torch.LongTensor`):
969
+ Used to indicate denoising step.
970
+ img_ids (`torch.Tensor`), txt_ids (`torch.Tensor`):
+ Per-token position indices for the image and text streams, with one column per RoPE axis.
+ guidance (`torch.Tensor`):
+ Guidance-scale values combined with the timestep in the time/guidance embedding.
972
+ joint_attention_kwargs (`dict`, *optional*):
973
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
974
+ `self.processor` in
975
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
976
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Kept for interface compatibility; this implementation always returns the output tensor directly.
+
+ Returns:
+ `torch.Tensor`: output latents of shape `(batch_size, image_sequence_length, patch_size * patch_size * out_channels)`.
983
+ """
984
+ # 0. Handle input arguments
985
+ if joint_attention_kwargs is not None:
986
+ joint_attention_kwargs = joint_attention_kwargs.copy()
987
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
988
+ else:
989
+ lora_scale = 1.0
990
+
991
+ num_txt_tokens = encoder_hidden_states.shape[1]
992
+
993
+ # 1. Calculate timestep embedding and modulation parameters
994
+ timestep = timestep.to(hidden_states.dtype) * 1000
995
+ guidance = guidance.to(hidden_states.dtype) * 1000
996
+
997
+ temb = self.time_guidance_embed(timestep, guidance)
998
+
999
+ double_stream_mod_img = self.double_stream_modulation_img(temb)
1000
+ double_stream_mod_txt = self.double_stream_modulation_txt(temb)
1001
+ single_stream_mod = self.single_stream_modulation(temb)[0]
1002
+
1003
+ # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
1004
+ hidden_states = self.x_embedder(hidden_states)
1005
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
1006
+
1007
+ # 3. Calculate RoPE embeddings from image and text tokens
1008
+ # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
1009
+ # text prompts of different lengths. Is this a use case we want to support?
1010
+ if img_ids.ndim == 3:
1011
+ img_ids = img_ids[0]
1012
+ if txt_ids.ndim == 3:
1013
+ txt_ids = txt_ids[0]
1014
+
1015
+ image_rotary_emb = self.pos_embed(img_ids)
1016
+ text_rotary_emb = self.pos_embed(txt_ids)
1017
+ concat_rotary_emb = (
1018
+ torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
1019
+ torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
1020
+ )
1021
+
1022
+ # 4. Double Stream Transformer Blocks
1023
+ for index_block, block in enumerate(self.transformer_blocks):
1024
+ encoder_hidden_states, hidden_states = gradient_checkpoint_forward(
1025
+ block,
1026
+ use_gradient_checkpointing=use_gradient_checkpointing,
1027
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
1028
+ hidden_states=hidden_states,
1029
+ encoder_hidden_states=encoder_hidden_states,
1030
+ temb_mod_params_img=double_stream_mod_img,
1031
+ temb_mod_params_txt=double_stream_mod_txt,
1032
+ image_rotary_emb=concat_rotary_emb,
1033
+ joint_attention_kwargs=joint_attention_kwargs,
1034
+ )
1035
+ # Concatenate text and image streams for single-block inference
1036
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
1037
+
1038
+ # 5. Single Stream Transformer Blocks
1039
+ for index_block, block in enumerate(self.single_transformer_blocks):
1040
+ hidden_states = gradient_checkpoint_forward(
1041
+ block,
1042
+ use_gradient_checkpointing=use_gradient_checkpointing,
1043
+ use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
1044
+ hidden_states=hidden_states,
1045
+ encoder_hidden_states=None,
1046
+ temb_mod_params=single_stream_mod,
1047
+ image_rotary_emb=concat_rotary_emb,
1048
+ joint_attention_kwargs=joint_attention_kwargs,
1049
+ )
1050
+ # Remove text tokens from concatenated stream
1051
+ hidden_states = hidden_states[:, num_txt_tokens:, ...]
1052
+
1053
+ # 6. Output layers
1054
+ hidden_states = self.norm_out(hidden_states, temb)
1055
+ output = self.proj_out(hidden_states)
1056
+
1057
+ return output
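
A small smoke-test sketch for the transformer above (not part of the commit). It uses a deliberately tiny configuration rather than the full-size defaults, and it assumes DiffSynth's `attention_forward` / `gradient_checkpoint_forward` helpers run on the local device; the per-axis RoPE dims are chosen so they sum to `attention_head_dim`.

```python
import torch
from diffsynth.models.flux2_dit import Flux2DiT

# Tiny config: 4 heads x 32 dims, one double-stream and one single-stream block.
# axes_dims_rope must sum to attention_head_dim (8 * 4 = 32 here).
model = Flux2DiT(
    in_channels=16, num_layers=1, num_single_layers=1,
    attention_head_dim=32, num_attention_heads=4,
    joint_attention_dim=64, axes_dims_rope=(8, 8, 8, 8),
).eval()

batch, img_tokens, txt_tokens = 1, 24, 6
hidden_states = torch.randn(batch, img_tokens, 16)           # packed image latents
encoder_hidden_states = torch.randn(batch, txt_tokens, 64)   # text-encoder features
img_ids = torch.zeros(img_tokens, 4)                         # one column per RoPE axis
img_ids[:, 2] = torch.arange(img_tokens)                     # e.g. a flattened spatial index
txt_ids = torch.zeros(txt_tokens, 4)
timestep = torch.tensor([0.5])
guidance = torch.tensor([4.0])

with torch.no_grad():
    out = model(
        hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep, img_ids=img_ids, txt_ids=txt_ids, guidance=guidance,
    )
print(out.shape)  # torch.Size([1, 24, 16]) -- patch_size=1, out_channels defaults to in_channels
```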
diffsynth/models/flux2_text_encoder.py ADDED
@@ -0,0 +1,58 @@
1
+ from transformers import Mistral3ForConditionalGeneration, Mistral3Config
2
+
3
+
4
+ class Flux2TextEncoder(Mistral3ForConditionalGeneration):
5
+ def __init__(self):
6
+ config = Mistral3Config(**{
7
+ "architectures": [
8
+ "Mistral3ForConditionalGeneration"
9
+ ],
10
+ "dtype": "bfloat16",
11
+ "image_token_index": 10,
12
+ "model_type": "mistral3",
13
+ "multimodal_projector_bias": False,
14
+ "projector_hidden_act": "gelu",
15
+ "spatial_merge_size": 2,
16
+ "text_config": {
17
+ "attention_dropout": 0.0,
18
+ "dtype": "bfloat16",
19
+ "head_dim": 128,
20
+ "hidden_act": "silu",
21
+ "hidden_size": 5120,
22
+ "initializer_range": 0.02,
23
+ "intermediate_size": 32768,
24
+ "max_position_embeddings": 131072,
25
+ "model_type": "mistral",
26
+ "num_attention_heads": 32,
27
+ "num_hidden_layers": 40,
28
+ "num_key_value_heads": 8,
29
+ "rms_norm_eps": 1e-05,
30
+ "rope_theta": 1000000000.0,
31
+ "sliding_window": None,
32
+ "use_cache": True,
33
+ "vocab_size": 131072
34
+ },
35
+ "transformers_version": "4.57.1",
36
+ "vision_config": {
37
+ "attention_dropout": 0.0,
38
+ "dtype": "bfloat16",
39
+ "head_dim": 64,
40
+ "hidden_act": "silu",
41
+ "hidden_size": 1024,
42
+ "image_size": 1540,
43
+ "initializer_range": 0.02,
44
+ "intermediate_size": 4096,
45
+ "model_type": "pixtral",
46
+ "num_attention_heads": 16,
47
+ "num_channels": 3,
48
+ "num_hidden_layers": 24,
49
+ "patch_size": 14,
50
+ "rope_theta": 10000.0
51
+ },
52
+ "vision_feature_layer": -1
53
+ })
54
+ super().__init__(config)
55
+
56
+ def forward(self, input_ids=None, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, cache_position=None, logits_to_keep=0, image_sizes=None, **kwargs):
+ # Pass everything through by keyword so the call does not depend on the positional argument order of the transformers base class.
+ return super().forward(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, logits_to_keep=logits_to_keep, image_sizes=image_sizes, **kwargs)
58
+
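
A quick inspection sketch for the wrapper above (not part of the commit). Instantiating the class builds the full Mistral3 architecture described by the hard-coded config (roughly a 24B-parameter network) with random weights, so the sketch constructs it on the `meta` device purely to look at the configuration; if your PyTorch/transformers combination objects to meta-device construction, drop the context manager and accept the memory cost. Real usage would materialise the module and load the pretrained state dict.

```python
import torch
from diffsynth.models.flux2_text_encoder import Flux2TextEncoder

# Build the module without allocating weight storage, just to inspect the config.
with torch.device("meta"):
    encoder = Flux2TextEncoder()

print(encoder.config.text_config.hidden_size)        # 5120
print(encoder.config.text_config.num_hidden_layers)  # 40
print(f"{sum(p.numel() for p in encoder.parameters()) / 1e9:.1f}B parameters")
```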
diffsynth/models/flux2_vae.py ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/models/flux_controlnet.py ADDED
@@ -0,0 +1,384 @@
1
+ import torch
+ import hashlib
+ from einops import rearrange, repeat
+ from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
+ # from .utils import hash_state_dict_keys, init_weights_on_device
+ from contextlib import contextmanager
6
+
7
+ def hash_state_dict_keys(state_dict, with_shape=True):
+ # NOTE: `convert_state_dict_keys_to_single_str` is not defined or imported in this file and must be
+ # provided by the module that ships the original helper.
+ keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
+ keys_str = keys_str.encode(encoding="UTF-8")
+ return hashlib.md5(keys_str).hexdigest()
11
+
12
+ @contextmanager
13
+ def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
14
+
15
+ old_register_parameter = torch.nn.Module.register_parameter
16
+ if include_buffers:
17
+ old_register_buffer = torch.nn.Module.register_buffer
18
+
19
+ def register_empty_parameter(module, name, param):
20
+ old_register_parameter(module, name, param)
21
+ if param is not None:
22
+ param_cls = type(module._parameters[name])
23
+ kwargs = module._parameters[name].__dict__
24
+ kwargs["requires_grad"] = param.requires_grad
25
+ module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
26
+
27
+ def register_empty_buffer(module, name, buffer, persistent=True):
28
+ old_register_buffer(module, name, buffer, persistent=persistent)
29
+ if buffer is not None:
30
+ module._buffers[name] = module._buffers[name].to(device)
31
+
32
+ def patch_tensor_constructor(fn):
33
+ def wrapper(*args, **kwargs):
34
+ kwargs["device"] = device
35
+ return fn(*args, **kwargs)
36
+
37
+ return wrapper
38
+
39
+ if include_buffers:
40
+ tensor_constructors_to_patch = {
41
+ torch_function_name: getattr(torch, torch_function_name)
42
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
43
+ }
44
+ else:
45
+ tensor_constructors_to_patch = {}
46
+
47
+ try:
48
+ torch.nn.Module.register_parameter = register_empty_parameter
49
+ if include_buffers:
50
+ torch.nn.Module.register_buffer = register_empty_buffer
51
+ for torch_function_name in tensor_constructors_to_patch.keys():
52
+ setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
53
+ yield
54
+ finally:
55
+ torch.nn.Module.register_parameter = old_register_parameter
56
+ if include_buffers:
57
+ torch.nn.Module.register_buffer = old_register_buffer
58
+ for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
59
+ setattr(torch, torch_function_name, old_torch_function)
60
+
61
+ class FluxControlNet(torch.nn.Module):
62
+ def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
63
+ super().__init__()
64
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
65
+ self.time_embedder = TimestepEmbeddings(256, 3072)
66
+ self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
67
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
68
+ self.context_embedder = torch.nn.Linear(4096, 3072)
69
+ self.x_embedder = torch.nn.Linear(64, 3072)
70
+
71
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_joint_blocks)])
72
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(num_single_blocks)])
73
+
74
+ self.controlnet_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_joint_blocks)])
75
+ self.controlnet_single_blocks = torch.nn.ModuleList([torch.nn.Linear(3072, 3072) for _ in range(num_single_blocks)])
76
+
77
+ self.mode_dict = mode_dict
78
+ self.controlnet_mode_embedder = torch.nn.Embedding(num_mode, 3072) if len(mode_dict) > 0 else None
79
+ self.controlnet_x_embedder = torch.nn.Linear(64 + additional_input_dim, 3072)
80
+
81
+
82
+ def prepare_image_ids(self, latents):
83
+ batch_size, _, height, width = latents.shape
84
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
85
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
86
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
87
+
88
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
89
+
90
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
91
+ latent_image_ids = latent_image_ids.reshape(
92
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
93
+ )
94
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
95
+
96
+ return latent_image_ids
97
+
98
+
99
+ def patchify(self, hidden_states):
100
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
101
+ return hidden_states
102
+
103
+
104
+ def align_res_stack_to_original_blocks(self, res_stack, num_blocks, hidden_states):
105
+ if len(res_stack) == 0:
106
+ return [torch.zeros_like(hidden_states)] * num_blocks
107
+ interval = (num_blocks + len(res_stack) - 1) // len(res_stack)
108
+ aligned_res_stack = [res_stack[block_id // interval] for block_id in range(num_blocks)]
109
+ return aligned_res_stack
110
+
111
+
112
+ def forward(
113
+ self,
114
+ hidden_states,
115
+ controlnet_conditioning,
116
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
117
+ processor_id=None,
118
+ tiled=False, tile_size=128, tile_stride=64,
119
+ **kwargs
120
+ ):
121
+ if image_ids is None:
122
+ image_ids = self.prepare_image_ids(hidden_states)
123
+
124
+ conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
125
+ if self.guidance_embedder is not None:
126
+ guidance = guidance * 1000
127
+ conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
128
+ prompt_emb = self.context_embedder(prompt_emb)
129
+ if self.controlnet_mode_embedder is not None: # Different from FluxDiT
130
+ processor_id = torch.tensor([self.mode_dict[processor_id]], dtype=torch.int)
131
+ processor_id = repeat(processor_id, "D -> B D", B=1).to(text_ids.device)
132
+ prompt_emb = torch.concat([self.controlnet_mode_embedder(processor_id), prompt_emb], dim=1)
133
+ text_ids = torch.cat([text_ids[:, :1], text_ids], dim=1)
134
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
135
+
136
+ hidden_states = self.patchify(hidden_states)
137
+ hidden_states = self.x_embedder(hidden_states)
138
+ controlnet_conditioning = self.patchify(controlnet_conditioning) # Different from FluxDiT
139
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_conditioning) # Different from FluxDiT
140
+
141
+ controlnet_res_stack = []
142
+ for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
143
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
144
+ controlnet_res_stack.append(controlnet_block(hidden_states))
145
+
146
+ controlnet_single_res_stack = []
147
+ hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
148
+ for block, controlnet_block in zip(self.single_blocks, self.controlnet_single_blocks):
149
+ hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
150
+ controlnet_single_res_stack.append(controlnet_block(hidden_states[:, prompt_emb.shape[1]:]))
151
+
152
+ controlnet_res_stack = self.align_res_stack_to_original_blocks(controlnet_res_stack, 19, hidden_states[:, prompt_emb.shape[1]:])
153
+ controlnet_single_res_stack = self.align_res_stack_to_original_blocks(controlnet_single_res_stack, 38, hidden_states[:, prompt_emb.shape[1]:])
154
+
155
+ return controlnet_res_stack, controlnet_single_res_stack
156
+
157
+
158
+ # @staticmethod
159
+ # def state_dict_converter():
160
+ # return FluxControlNetStateDictConverter()
161
+
162
+ def quantize(self):
163
+ def cast_to(weight, dtype=None, device=None, copy=False):
164
+ if device is None or weight.device == device:
165
+ if not copy:
166
+ if dtype is None or weight.dtype == dtype:
167
+ return weight
168
+ return weight.to(dtype=dtype, copy=copy)
169
+
170
+ r = torch.empty_like(weight, dtype=dtype, device=device)
171
+ r.copy_(weight)
172
+ return r
173
+
174
+ def cast_weight(s, input=None, dtype=None, device=None):
175
+ if input is not None:
176
+ if dtype is None:
177
+ dtype = input.dtype
178
+ if device is None:
179
+ device = input.device
180
+ weight = cast_to(s.weight, dtype, device)
181
+ return weight
182
+
183
+ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
184
+ if input is not None:
185
+ if dtype is None:
186
+ dtype = input.dtype
187
+ if bias_dtype is None:
188
+ bias_dtype = dtype
189
+ if device is None:
190
+ device = input.device
191
+ bias = None
192
+ weight = cast_to(s.weight, dtype, device)
193
+ bias = cast_to(s.bias, bias_dtype, device)
194
+ return weight, bias
195
+
196
+ class quantized_layer:
197
+ class QLinear(torch.nn.Linear):
198
+ def __init__(self, *args, **kwargs):
199
+ super().__init__(*args, **kwargs)
200
+
201
+ def forward(self,input,**kwargs):
202
+ weight,bias= cast_bias_weight(self,input)
203
+ return torch.nn.functional.linear(input,weight,bias)
204
+
205
+ class QRMSNorm(torch.nn.Module):
206
+ def __init__(self, module):
207
+ super().__init__()
208
+ self.module = module
209
+
210
+ def forward(self,hidden_states,**kwargs):
211
+ weight= cast_weight(self.module,hidden_states)
212
+ input_dtype = hidden_states.dtype
213
+ variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
214
+ hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
215
+ hidden_states = hidden_states.to(input_dtype) * weight
216
+ return hidden_states
217
+
218
+ class QEmbedding(torch.nn.Embedding):
219
+ def __init__(self, *args, **kwargs):
220
+ super().__init__(*args, **kwargs)
221
+
222
+ def forward(self,input,**kwargs):
223
+ weight= cast_weight(self,input)
224
+ return torch.nn.functional.embedding(
225
+ input, weight, self.padding_idx, self.max_norm,
226
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
227
+
228
+ def replace_layer(model):
229
+ for name, module in model.named_children():
230
+ if isinstance(module,quantized_layer.QRMSNorm):
231
+ continue
232
+ if isinstance(module, torch.nn.Linear):
233
+ with init_weights_on_device():
234
+ new_layer = quantized_layer.QLinear(module.in_features,module.out_features)
235
+ new_layer.weight = module.weight
236
+ if module.bias is not None:
237
+ new_layer.bias = module.bias
238
+ setattr(model, name, new_layer)
239
+ elif isinstance(module, RMSNorm):
240
+ if hasattr(module,"quantized"):
241
+ continue
242
+ module.quantized= True
243
+ new_layer = quantized_layer.QRMSNorm(module)
244
+ setattr(model, name, new_layer)
245
+ elif isinstance(module,torch.nn.Embedding):
246
+ rows, cols = module.weight.shape
247
+ new_layer = quantized_layer.QEmbedding(
248
+ num_embeddings=rows,
249
+ embedding_dim=cols,
250
+ _weight=module.weight,
251
+ # _freeze=module.freeze,
252
+ padding_idx=module.padding_idx,
253
+ max_norm=module.max_norm,
254
+ norm_type=module.norm_type,
255
+ scale_grad_by_freq=module.scale_grad_by_freq,
256
+ sparse=module.sparse)
257
+ setattr(model, name, new_layer)
258
+ else:
259
+ replace_layer(module)
260
+
261
+ replace_layer(self)
262
+
263
+
264
+
265
+ class FluxControlNetStateDictConverter:
266
+ def __init__(self):
267
+ pass
268
+
269
+ def from_diffusers(self, state_dict):
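+ # Rename Diffusers-style keys to this repo's layout, fuse separate q/k/v projections into single qkv weights, and pick architecture kwargs from the checkpoint's key hash.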
270
+ hash_value = hash_state_dict_keys(state_dict)
271
+ global_rename_dict = {
272
+ "context_embedder": "context_embedder",
273
+ "x_embedder": "x_embedder",
274
+ "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
275
+ "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
276
+ "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
277
+ "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
278
+ "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
279
+ "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
280
+ "norm_out.linear": "final_norm_out.linear",
281
+ "proj_out": "final_proj_out",
282
+ }
283
+ rename_dict = {
284
+ "proj_out": "proj_out",
285
+ "norm1.linear": "norm1_a.linear",
286
+ "norm1_context.linear": "norm1_b.linear",
287
+ "attn.to_q": "attn.a_to_q",
288
+ "attn.to_k": "attn.a_to_k",
289
+ "attn.to_v": "attn.a_to_v",
290
+ "attn.to_out.0": "attn.a_to_out",
291
+ "attn.add_q_proj": "attn.b_to_q",
292
+ "attn.add_k_proj": "attn.b_to_k",
293
+ "attn.add_v_proj": "attn.b_to_v",
294
+ "attn.to_add_out": "attn.b_to_out",
295
+ "ff.net.0.proj": "ff_a.0",
296
+ "ff.net.2": "ff_a.2",
297
+ "ff_context.net.0.proj": "ff_b.0",
298
+ "ff_context.net.2": "ff_b.2",
299
+ "attn.norm_q": "attn.norm_q_a",
300
+ "attn.norm_k": "attn.norm_k_a",
301
+ "attn.norm_added_q": "attn.norm_q_b",
302
+ "attn.norm_added_k": "attn.norm_k_b",
303
+ }
304
+ rename_dict_single = {
305
+ "attn.to_q": "a_to_q",
306
+ "attn.to_k": "a_to_k",
307
+ "attn.to_v": "a_to_v",
308
+ "attn.norm_q": "norm_q_a",
309
+ "attn.norm_k": "norm_k_a",
310
+ "norm.linear": "norm.linear",
311
+ "proj_mlp": "proj_in_besides_attn",
312
+ "proj_out": "proj_out",
313
+ }
314
+ state_dict_ = {}
315
+ for name, param in state_dict.items():
316
+ if name.endswith(".weight") or name.endswith(".bias"):
317
+ suffix = ".weight" if name.endswith(".weight") else ".bias"
318
+ prefix = name[:-len(suffix)]
319
+ if prefix in global_rename_dict:
320
+ state_dict_[global_rename_dict[prefix] + suffix] = param
321
+ elif prefix.startswith("transformer_blocks."):
322
+ names = prefix.split(".")
323
+ names[0] = "blocks"
324
+ middle = ".".join(names[2:])
325
+ if middle in rename_dict:
326
+ name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
327
+ state_dict_[name_] = param
328
+ elif prefix.startswith("single_transformer_blocks."):
329
+ names = prefix.split(".")
330
+ names[0] = "single_blocks"
331
+ middle = ".".join(names[2:])
332
+ if middle in rename_dict_single:
333
+ name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
334
+ state_dict_[name_] = param
335
+ else:
336
+ state_dict_[name] = param
337
+ else:
338
+ state_dict_[name] = param
339
+ for name in list(state_dict_.keys()):
340
+ if ".proj_in_besides_attn." in name:
341
+ name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
342
+ param = torch.concat([
343
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
344
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
345
+ state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
346
+ state_dict_[name],
347
+ ], dim=0)
348
+ state_dict_[name_] = param
349
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
350
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
351
+ state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
352
+ state_dict_.pop(name)
353
+ for name in list(state_dict_.keys()):
354
+ for component in ["a", "b"]:
355
+ if f".{component}_to_q." in name:
356
+ name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
357
+ param = torch.concat([
358
+ state_dict_[name],
359
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
360
+ state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
361
+ ], dim=0)
362
+ state_dict_[name_] = param
363
+ state_dict_.pop(name)
364
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
365
+ state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
366
+ if hash_value == "78d18b9101345ff695f312e7e62538c0":
367
+ extra_kwargs = {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}
368
+ elif hash_value == "b001c89139b5f053c715fe772362dd2a":
369
+ extra_kwargs = {"num_single_blocks": 0}
370
+ elif hash_value == "52357cb26250681367488a8954c271e8":
371
+ extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}
372
+ elif hash_value == "0cfd1740758423a2a854d67c136d1e8c":
373
+ extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 1}
374
+ elif hash_value == "7f9583eb8ba86642abb9a21a4b2c9e16":
375
+ extra_kwargs = {"num_joint_blocks": 4, "num_single_blocks": 10}
376
+ elif hash_value == "43ad5aaa27dd4ee01b832ed16773fa52":
377
+ extra_kwargs = {"num_joint_blocks": 6, "num_single_blocks": 0}
378
+ else:
379
+ extra_kwargs = {}
380
+ return state_dict_, extra_kwargs
381
+
382
+
383
+ def from_civitai(self, state_dict):
384
+ return self.from_diffusers(state_dict)
diffsynth/models/flux_dit.py ADDED
@@ -0,0 +1,395 @@
1
+ import torch
2
+ from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
3
+ from einops import rearrange
4
+
5
+
6
+ def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
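+ # Decoupled IP-Adapter attention: attend the existing queries over the image-conditioned keys/values and add the result to hidden_states, weighted by `scale`.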
7
+ batch_size, num_tokens = hidden_states.shape[0:2]
8
+ ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
9
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, num_tokens, -1)
10
+ hidden_states = hidden_states + scale * ip_hidden_states
11
+ return hidden_states
12
+
13
+
14
+ class RoPEEmbedding(torch.nn.Module):
15
+ def __init__(self, dim, theta, axes_dim):
16
+ super().__init__()
17
+ self.dim = dim
18
+ self.theta = theta
19
+ self.axes_dim = axes_dim
20
+
21
+
22
+ def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
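+ # Per-axis RoPE: build 2x2 rotation matrices (cos/sin of position times frequency) for each channel pair along this axis.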
23
+ assert dim % 2 == 0, "The dimension must be even."
24
+
25
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
26
+ omega = 1.0 / (theta**scale)
27
+
28
+ batch_size, seq_length = pos.shape
29
+ out = torch.einsum("...n,d->...nd", pos, omega)
30
+ cos_out = torch.cos(out)
31
+ sin_out = torch.sin(out)
32
+
33
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
34
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
35
+ return out.float()
36
+
37
+
38
+ def forward(self, ids):
39
+ n_axes = ids.shape[-1]
40
+ emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
41
+ return emb.unsqueeze(1)
42
+
43
+
44
+
45
+ class FluxJointAttention(torch.nn.Module):
46
+ def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
47
+ super().__init__()
48
+ self.num_heads = num_heads
49
+ self.head_dim = head_dim
50
+ self.only_out_a = only_out_a
51
+
52
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
53
+ self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
54
+
55
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
56
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
57
+ self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
58
+ self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
59
+
60
+ self.a_to_out = torch.nn.Linear(dim_a, dim_a)
61
+ if not only_out_a:
62
+ self.b_to_out = torch.nn.Linear(dim_b, dim_b)
63
+
64
+
65
+ def apply_rope(self, xq, xk, freqs_cis):
66
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
67
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
68
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
69
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
70
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
71
+
72
+ def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
73
+ batch_size = hidden_states_a.shape[0]
74
+
75
+ # Part A
76
+ qkv_a = self.a_to_qkv(hidden_states_a)
77
+ qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
78
+ q_a, k_a, v_a = qkv_a.chunk(3, dim=1)
79
+ q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
80
+
81
+ # Part B
82
+ qkv_b = self.b_to_qkv(hidden_states_b)
83
+ qkv_b = qkv_b.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
84
+ q_b, k_b, v_b = qkv_b.chunk(3, dim=1)
85
+ q_b, k_b = self.norm_q_b(q_b), self.norm_k_b(k_b)
86
+
87
+ q = torch.concat([q_b, q_a], dim=2)
88
+ k = torch.concat([k_b, k_a], dim=2)
89
+ v = torch.concat([v_b, v_a], dim=2)
90
+
91
+ q, k = self.apply_rope(q, k, image_rotary_emb)
92
+
93
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
94
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
95
+ hidden_states = hidden_states.to(q.dtype)
96
+ hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
97
+ if ipadapter_kwargs_list is not None:
98
+ hidden_states_a = interact_with_ipadapter(hidden_states_a, q_a, **ipadapter_kwargs_list)
99
+ hidden_states_a = self.a_to_out(hidden_states_a)
100
+ if self.only_out_a:
101
+ return hidden_states_a
102
+ else:
103
+ hidden_states_b = self.b_to_out(hidden_states_b)
104
+ return hidden_states_a, hidden_states_b
105
+
106
+
107
+
108
+ class FluxJointTransformerBlock(torch.nn.Module):
109
+ def __init__(self, dim, num_attention_heads):
110
+ super().__init__()
111
+ self.norm1_a = AdaLayerNorm(dim)
112
+ self.norm1_b = AdaLayerNorm(dim)
113
+
114
+ self.attn = FluxJointAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
115
+
116
+ self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
117
+ self.ff_a = torch.nn.Sequential(
118
+ torch.nn.Linear(dim, dim*4),
119
+ torch.nn.GELU(approximate="tanh"),
120
+ torch.nn.Linear(dim*4, dim)
121
+ )
122
+
123
+ self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
124
+ self.ff_b = torch.nn.Sequential(
125
+ torch.nn.Linear(dim, dim*4),
126
+ torch.nn.GELU(approximate="tanh"),
127
+ torch.nn.Linear(dim*4, dim)
128
+ )
129
+
130
+
131
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
132
+ norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
133
+ norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
134
+
135
+ # Attention
136
+ attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
137
+
138
+ # Part A
139
+ hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
140
+ norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
141
+ hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
142
+
143
+ # Part B
144
+ hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
145
+ norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
146
+ hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
147
+
148
+ return hidden_states_a, hidden_states_b
149
+
150
+
151
+
152
+ class FluxSingleAttention(torch.nn.Module):
153
+ def __init__(self, dim_a, dim_b, num_heads, head_dim):
154
+ super().__init__()
155
+ self.num_heads = num_heads
156
+ self.head_dim = head_dim
157
+
158
+ self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
159
+
160
+ self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
161
+ self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
162
+
163
+
164
+ def apply_rope(self, xq, xk, freqs_cis):
165
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
166
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
167
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
168
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
169
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
170
+
171
+
172
+ def forward(self, hidden_states, image_rotary_emb):
173
+ batch_size = hidden_states.shape[0]
174
+
175
+ qkv_a = self.a_to_qkv(hidden_states)
176
+ qkv_a = qkv_a.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
177
+ q_a, k_a, v = qkv_a.chunk(3, dim=1)
178
+ q_a, k_a = self.norm_q_a(q_a), self.norm_k_a(k_a)
179
+
180
+ q, k = self.apply_rope(q_a, k_a, image_rotary_emb)
181
+
182
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
183
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
184
+ hidden_states = hidden_states.to(q.dtype)
185
+ return hidden_states
186
+
187
+
188
+
189
+ class AdaLayerNormSingle(torch.nn.Module):
190
+ def __init__(self, dim):
191
+ super().__init__()
192
+ self.silu = torch.nn.SiLU()
193
+ self.linear = torch.nn.Linear(dim, 3 * dim, bias=True)
194
+ self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
195
+
196
+
197
+ def forward(self, x, emb):
198
+ emb = self.linear(self.silu(emb))
199
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
200
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
201
+ return x, gate_msa
202
+
203
+
204
+
205
+ class FluxSingleTransformerBlock(torch.nn.Module):
206
+ def __init__(self, dim, num_attention_heads):
207
+ super().__init__()
208
+ self.num_heads = num_attention_heads
209
+ self.head_dim = dim // num_attention_heads
210
+ self.dim = dim
211
+
212
+ self.norm = AdaLayerNormSingle(dim)
213
+ self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
214
+ self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
215
+ self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
216
+
217
+ self.proj_out = torch.nn.Linear(dim * 5, dim)
218
+
219
+
220
+ def apply_rope(self, xq, xk, freqs_cis):
221
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
222
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
223
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
224
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
225
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
226
+
227
+
228
+ def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
229
+ batch_size = hidden_states.shape[0]
230
+
231
+ qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
232
+ q, k, v = qkv.chunk(3, dim=1)
233
+ q, k = self.norm_q_a(q), self.norm_k_a(k)
234
+
235
+ q, k = self.apply_rope(q, k, image_rotary_emb)
236
+
237
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
238
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
239
+ hidden_states = hidden_states.to(q.dtype)
240
+ if ipadapter_kwargs_list is not None:
241
+ hidden_states = interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs_list)
242
+ return hidden_states
243
+
244
+
245
+ def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
246
+ residual = hidden_states_a
247
+ norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
248
+ hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
249
+ attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
250
+
251
+ attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
252
+ mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
253
+
254
+ hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
255
+ hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
256
+ hidden_states_a = residual + hidden_states_a
257
+
258
+ return hidden_states_a, hidden_states_b
259
+
260
+
261
+
262
+ class AdaLayerNormContinuous(torch.nn.Module):
263
+ def __init__(self, dim):
264
+ super().__init__()
265
+ self.silu = torch.nn.SiLU()
266
+ self.linear = torch.nn.Linear(dim, dim * 2, bias=True)
267
+ self.norm = torch.nn.LayerNorm(dim, eps=1e-6, elementwise_affine=False)
268
+
269
+ def forward(self, x, conditioning):
270
+ emb = self.linear(self.silu(conditioning))
271
+ shift, scale = torch.chunk(emb, 2, dim=1)
272
+ x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
273
+ return x
274
+
275
+
276
+
277
+ class FluxDiT(torch.nn.Module):
278
+ def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
279
+ super().__init__()
280
+ self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
281
+ self.time_embedder = TimestepEmbeddings(256, 3072)
282
+ self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
283
+ self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
284
+ self.context_embedder = torch.nn.Linear(4096, 3072)
285
+ self.x_embedder = torch.nn.Linear(input_dim, 3072)
286
+
287
+ self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
288
+ self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
289
+
290
+ self.final_norm_out = AdaLayerNormContinuous(3072)
291
+ self.final_proj_out = torch.nn.Linear(3072, 64)
292
+
293
+ self.input_dim = input_dim
294
+
295
+
296
+ def patchify(self, hidden_states):
297
+ hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
298
+ return hidden_states
299
+
300
+
301
+ def unpatchify(self, hidden_states, height, width):
302
+ hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
303
+ return hidden_states
304
+
305
+
306
+ def prepare_image_ids(self, latents):
307
+ batch_size, _, height, width = latents.shape
308
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
309
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
310
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
311
+
312
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
313
+
314
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
315
+ latent_image_ids = latent_image_ids.reshape(
316
+ batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
317
+ )
318
+ latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
319
+
320
+ return latent_image_ids
321
+
322
+
323
+ def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
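+ # Build a block attention mask: each prompt segment attends only to image tokens inside its entity mask (the appended global mask covers everything), different prompt segments cannot attend to each other, and image-image attention stays unrestricted.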
324
+ N = len(entity_masks)
325
+ batch_size = entity_masks[0].shape[0]
326
+ total_seq_len = N * prompt_seq_len + image_seq_len
327
+ patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
328
+ attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
329
+
330
+ image_start = N * prompt_seq_len
331
+ image_end = N * prompt_seq_len + image_seq_len
332
+ # prompt-image mask
333
+ for i in range(N):
334
+ prompt_start = i * prompt_seq_len
335
+ prompt_end = (i + 1) * prompt_seq_len
336
+ image_mask = torch.sum(patched_masks[i], dim=-1) > 0
337
+ image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
338
+ # prompt update with image
339
+ attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
340
+ # image update with prompt
341
+ attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
342
+ # prompt-prompt mask
343
+ for i in range(N):
344
+ for j in range(N):
345
+ if i != j:
346
+ prompt_start_i = i * prompt_seq_len
347
+ prompt_end_i = (i + 1) * prompt_seq_len
348
+ prompt_start_j = j * prompt_seq_len
349
+ prompt_end_j = (j + 1) * prompt_seq_len
350
+ attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
351
+
352
+ attention_mask = attention_mask.float()
353
+ attention_mask[attention_mask == 0] = float('-inf')
354
+ attention_mask[attention_mask == 1] = 0
355
+ return attention_mask
356
+
357
+
358
+ def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids, repeat_dim):
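+ # Embed each entity prompt plus the global prompt, concatenate them along the sequence dimension, and build the matching attention mask and rotary embeddings.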
359
+ max_masks = 0
360
+ attention_mask = None
361
+ prompt_embs = [prompt_emb]
362
+ if entity_masks is not None:
363
+ # entity_masks
364
+ batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
365
+ entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
366
+ entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
367
+ # global mask
368
+ global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
369
+ entity_masks = entity_masks + [global_mask] # append global to last
370
+ # attention mask
371
+ attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
372
+ attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
373
+ attention_mask = attention_mask.unsqueeze(1)
374
+ # embds: n_masks * b * seq * d
375
+ local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
376
+ prompt_embs = local_embs + prompt_embs # append global to last
377
+ prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
378
+ prompt_emb = torch.cat(prompt_embs, dim=1)
379
+
380
+ # positional embedding
381
+ text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
382
+ image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
383
+ return prompt_emb, image_rotary_emb, attention_mask
384
+
385
+
386
+ def forward(
387
+ self,
388
+ hidden_states,
389
+ timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
390
+ tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
391
+ use_gradient_checkpointing=False,
392
+ **kwargs
393
+ ):
394
+ # (Deprecated) The real forward is in `pipelines.flux_image`.
395
+ return None
diffsynth/models/flux_infiniteyou.py ADDED
@@ -0,0 +1,129 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ # FFN
7
+ def FeedForward(dim, mult=4):
8
+ inner_dim = int(dim * mult)
9
+ return nn.Sequential(
10
+ nn.LayerNorm(dim),
11
+ nn.Linear(dim, inner_dim, bias=False),
12
+ nn.GELU(),
13
+ nn.Linear(inner_dim, dim, bias=False),
14
+ )
15
+
16
+
17
+ def reshape_tensor(x, heads):
18
+ bs, length, width = x.shape
19
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
20
+ x = x.view(bs, length, heads, -1)
21
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
22
+ x = x.transpose(1, 2)
23
+ # (bs, n_heads, length, dim_per_head): reshape keeps the same shape, only making the tensor contiguous (heads are not merged)
24
+ x = x.reshape(bs, heads, length, -1)
25
+ return x
26
+
27
+
28
+ class PerceiverAttention(nn.Module):
29
+
30
+ def __init__(self, *, dim, dim_head=64, heads=8):
31
+ super().__init__()
32
+ self.scale = dim_head**-0.5
33
+ self.dim_head = dim_head
34
+ self.heads = heads
35
+ inner_dim = dim_head * heads
36
+
37
+ self.norm1 = nn.LayerNorm(dim)
38
+ self.norm2 = nn.LayerNorm(dim)
39
+
40
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
41
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
42
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
43
+
44
+ def forward(self, x, latents):
45
+ """
46
+ Args:
47
+ x (torch.Tensor): image features
48
+ shape (b, n1, D)
49
+ latents (torch.Tensor): latent features
50
+ shape (b, n2, D)
51
+ """
52
+ x = self.norm1(x)
53
+ latents = self.norm2(latents)
54
+
55
+ b, l, _ = latents.shape
56
+
57
+ q = self.to_q(latents)
58
+ kv_input = torch.cat((x, latents), dim=-2)
59
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
60
+
61
+ q = reshape_tensor(q, self.heads)
62
+ k = reshape_tensor(k, self.heads)
63
+ v = reshape_tensor(v, self.heads)
64
+
65
+ # attention
66
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
67
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
68
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
69
+ out = weight @ v
70
+
71
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
72
+
73
+ return self.to_out(out)
74
+
75
+
76
+ class InfiniteYouImageProjector(nn.Module):
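+ # Perceiver-style resampler: learned latent queries repeatedly attend to the projected ID embeddings and are then projected to the text-embedding width (output_dim).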
77
+
78
+ def __init__(
79
+ self,
80
+ dim=1280,
81
+ depth=4,
82
+ dim_head=64,
83
+ heads=20,
84
+ num_queries=8,
85
+ embedding_dim=512,
86
+ output_dim=4096,
87
+ ff_mult=4,
88
+ ):
89
+ super().__init__()
90
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
91
+ self.proj_in = nn.Linear(embedding_dim, dim)
92
+
93
+ self.proj_out = nn.Linear(dim, output_dim)
94
+ self.norm_out = nn.LayerNorm(output_dim)
95
+
96
+ self.layers = nn.ModuleList([])
97
+ for _ in range(depth):
98
+ self.layers.append(
99
+ nn.ModuleList([
100
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
101
+ FeedForward(dim=dim, mult=ff_mult),
102
+ ]))
103
+
104
+ def forward(self, x):
105
+
106
+ latents = self.latents.repeat(x.size(0), 1, 1)
107
+ latents = latents.to(dtype=x.dtype, device=x.device)
108
+
109
+ x = self.proj_in(x)
110
+
111
+ for attn, ff in self.layers:
112
+ latents = attn(x, latents) + latents
113
+ latents = ff(latents) + latents
114
+
115
+ latents = self.proj_out(latents)
116
+ return self.norm_out(latents)
117
+
118
+ @staticmethod
119
+ def state_dict_converter():
120
+ return FluxInfiniteYouImageProjectorStateDictConverter()
121
+
122
+
123
+ class FluxInfiniteYouImageProjectorStateDictConverter:
124
+
125
+ def __init__(self):
126
+ pass
127
+
128
+ def from_diffusers(self, state_dict):
129
+ return state_dict['image_proj']
diffsynth/models/flux_ipadapter.py ADDED
@@ -0,0 +1,110 @@
1
+ from .general_modules import RMSNorm
2
+ from transformers import SiglipVisionModel, SiglipVisionConfig
3
+ import torch
4
+
5
+
6
+ class SiglipVisionModelSO400M(SiglipVisionModel):
7
+ def __init__(self):
8
+ config = SiglipVisionConfig(
9
+ hidden_size=1152,
10
+ image_size=384,
11
+ intermediate_size=4304,
12
+ model_type="siglip_vision_model",
13
+ num_attention_heads=16,
14
+ num_hidden_layers=27,
15
+ patch_size=14,
16
+ architectures=["SiglipModel"],
17
+ initializer_factor=1.0,
18
+ torch_dtype="float32",
19
+ transformers_version="4.37.0.dev0"
20
+ )
21
+ super().__init__(config)
22
+
23
+ class MLPProjModel(torch.nn.Module):
24
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
25
+ super().__init__()
26
+
27
+ self.cross_attention_dim = cross_attention_dim
28
+ self.num_tokens = num_tokens
29
+
30
+ self.proj = torch.nn.Sequential(
31
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
32
+ torch.nn.GELU(),
33
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
34
+ )
35
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
36
+
37
+ def forward(self, id_embeds):
38
+ x = self.proj(id_embeds)
39
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
40
+ x = self.norm(x)
41
+ return x
42
+
43
+ class IpAdapterModule(torch.nn.Module):
44
+ def __init__(self, num_attention_heads, attention_head_dim, input_dim):
45
+ super().__init__()
46
+ self.num_heads = num_attention_heads
47
+ self.head_dim = attention_head_dim
48
+ output_dim = num_attention_heads * attention_head_dim
49
+ self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
50
+ self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
51
+ self.norm_added_k = RMSNorm(attention_head_dim, eps=1e-5, elementwise_affine=False)
52
+
53
+
54
+ def forward(self, hidden_states):
55
+ batch_size = hidden_states.shape[0]
56
+ # ip_k
57
+ ip_k = self.to_k_ip(hidden_states)
58
+ ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
59
+ ip_k = self.norm_added_k(ip_k)
60
+ # ip_v
61
+ ip_v = self.to_v_ip(hidden_states)
62
+ ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
63
+ return ip_k, ip_v
64
+
65
+
66
+ class FluxIpAdapter(torch.nn.Module):
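+ # Projects image embeddings into IP tokens and emits per-block {ip_k, ip_v, scale} consumed by the DiT attention via interact_with_ipadapter.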
67
+ def __init__(self, num_attention_heads=24, attention_head_dim=128, cross_attention_dim=4096, num_tokens=128, num_blocks=57):
68
+ super().__init__()
69
+ self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(num_attention_heads, attention_head_dim, cross_attention_dim) for _ in range(num_blocks)])
70
+ self.image_proj = MLPProjModel(cross_attention_dim=cross_attention_dim, id_embeddings_dim=1152, num_tokens=num_tokens)
71
+ self.set_adapter()
72
+
73
+ def set_adapter(self):
74
+ self.call_block_id = {i:i for i in range(len(self.ipadapter_modules))}
75
+
76
+ def forward(self, hidden_states, scale=1.0):
77
+ hidden_states = self.image_proj(hidden_states)
78
+ hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
79
+ ip_kv_dict = {}
80
+ for block_id in self.call_block_id:
81
+ ipadapter_id = self.call_block_id[block_id]
82
+ ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
83
+ ip_kv_dict[block_id] = {
84
+ "ip_k": ip_k,
85
+ "ip_v": ip_v,
86
+ "scale": scale
87
+ }
88
+ return ip_kv_dict
89
+
90
+ @staticmethod
91
+ def state_dict_converter():
92
+ return FluxIpAdapterStateDictConverter()
93
+
94
+
95
+ class FluxIpAdapterStateDictConverter:
96
+ def __init__(self):
97
+ pass
98
+
99
+ def from_diffusers(self, state_dict):
100
+ state_dict_ = {}
101
+ for name in state_dict["ip_adapter"]:
102
+ name_ = 'ipadapter_modules.' + name
103
+ state_dict_[name_] = state_dict["ip_adapter"][name]
104
+ for name in state_dict["image_proj"]:
105
+ name_ = "image_proj." + name
106
+ state_dict_[name_] = state_dict["image_proj"][name]
107
+ return state_dict_
108
+
109
+ def from_civitai(self, state_dict):
110
+ return self.from_diffusers(state_dict)